1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16
17 #include <utility>
18
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/core/common_runtime/device_set.h"
21 #include "tensorflow/core/common_runtime/function.h"
22 #include "tensorflow/core/common_runtime/optimization_registry.h"
23 #include "tensorflow/core/common_runtime/partitioning_utils.h"
24 #include "tensorflow/core/common_runtime/placer.h"
25 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
26 #include "tensorflow/core/common_runtime/rendezvous_util.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph_to_functiondef.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/graph/graph.h"
33 #include "tensorflow/core/graph/graph_constructor.h"
34 #include "tensorflow/core/graph/graph_partition.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/gtl/map_util.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 #include "tensorflow/core/util/ptr_util.h"
39 #include "tensorflow/core/util/reffed_status_callback.h"
40
41 namespace tensorflow {
42
// Sentinel device name that maps to the process-wide, device-less FLR
// (stored under the nullptr key in flr_map_; see GetFLR).
const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
44
DistributedInit(DistributedFunctionLibraryRuntime * parent,const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)45 Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
46 DistributedFunctionLibraryRuntime* parent, const string& function_name,
47 const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
48 const FunctionLibraryRuntime::InstantiateOptions& options) {
49 mutex_lock l(mu_);
50 if (!init_started_) {
51 init_started_ = true;
52 init_result_ = parent->Instantiate(function_name, lib_def, attrs, options,
53 &local_handle_);
54 }
55 return init_result_;
56 }
57
ProcessFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * default_thread_pool,DistributedFunctionLibraryRuntime * parent)58 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
59 const DeviceMgr* device_mgr, Env* env, int graph_def_version,
60 const FunctionLibraryDefinition* lib_def,
61 const OptimizerOptions& optimizer_options,
62 thread::ThreadPool* default_thread_pool,
63 DistributedFunctionLibraryRuntime* parent)
64 : env_(env),
65 device_mgr_(device_mgr),
66 lib_def_(lib_def),
67 default_thread_pool_(default_thread_pool),
68 next_handle_(0),
69 parent_(parent) {
70 if (device_mgr == nullptr) {
71 flr_map_[nullptr] = NewFunctionLibraryRuntime(
72 nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool,
73 optimizer_options, this);
74 return;
75 }
76 for (Device* d : device_mgr->ListDevices()) {
77 flr_map_[d] = NewFunctionLibraryRuntime(
78 device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool,
79 optimizer_options, this);
80 }
81 }
82
ProcessFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,thread::ThreadPool * default_thread_pool,DistributedFunctionLibraryRuntime * parent)83 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
84 const DeviceMgr* device_mgr, Env* env, int graph_def_version,
85 const FunctionLibraryDefinition* lib_def,
86 const OptimizerOptions& optimizer_options,
87 CustomKernelCreator custom_kernel_creator,
88 thread::ThreadPool* default_thread_pool,
89 DistributedFunctionLibraryRuntime* parent)
90 : env_(env),
91 device_mgr_(device_mgr),
92 lib_def_(lib_def),
93 default_thread_pool_(default_thread_pool),
94 next_handle_(0),
95 parent_(parent) {
96 if (device_mgr == nullptr) {
97 flr_map_[nullptr] = NewFunctionLibraryRuntime(
98 nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool,
99 optimizer_options, std::move(custom_kernel_creator), this);
100 return;
101 }
102 for (Device* d : device_mgr->ListDevices()) {
103 flr_map_[d] = NewFunctionLibraryRuntime(
104 device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool,
105 optimizer_options, custom_kernel_creator, this);
106 }
107 }
108
109 /* static */
SendTensors(const string & source_device,const string & target_device,const string & key_prefix,int64 src_incarnation,gtl::ArraySlice<Tensor> tensors_to_send,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,Rendezvous * rendezvous)110 Status ProcessFunctionLibraryRuntime::SendTensors(
111 const string& source_device, const string& target_device,
112 const string& key_prefix, int64 src_incarnation,
113 gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
114 const std::vector<AllocatorAttributes>& alloc_attrs,
115 Rendezvous* rendezvous) {
116 std::vector<string> keys;
117 for (int i = 0; i < tensors_to_send.size(); ++i) {
118 string name = strings::StrCat(key_prefix, i);
119 string key = Rendezvous::CreateKey(source_device, src_incarnation,
120 target_device, name, FrameAndIter(0, 0));
121 keys.push_back(key);
122 }
123 TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
124 rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
125 return Status::OK();
126 }
127
128 /* static */
ReceiveTensorsAsync(const string & source_device,const string & target_device,const string & key_prefix,int64 src_incarnation,int64 num_tensors,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,Rendezvous * rendezvous,std::vector<Tensor> * received_tensors,StatusCallback done)129 void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
130 const string& source_device, const string& target_device,
131 const string& key_prefix, int64 src_incarnation, int64 num_tensors,
132 DeviceContext* device_context,
133 const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
134 std::vector<Tensor>* received_tensors, StatusCallback done) {
135 std::vector<string> keys;
136 for (int64 i = 0; i < num_tensors; ++i) {
137 string name = strings::StrCat(key_prefix, i);
138 string key = Rendezvous::CreateKey(source_device, src_incarnation,
139 target_device, name, FrameAndIter(0, 0));
140 keys.push_back(key);
141 }
142 RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
143 received_tensors, std::move(done));
144 }
145
GetDeviceIncarnation(const string & device_name,int64 * incarnation) const146 Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
147 const string& device_name, int64* incarnation) const {
148 FunctionLibraryRuntime* flr = GetFLR(device_name);
149 if (flr == nullptr) {
150 return errors::InvalidArgument("Device name: ", device_name, " not found");
151 }
152 *incarnation = flr->device()->attributes().incarnation();
153 return Status::OK();
154 }
155
GetDeviceContext(const string & device_name,DeviceContext ** device_context) const156 Status ProcessFunctionLibraryRuntime::GetDeviceContext(
157 const string& device_name, DeviceContext** device_context) const {
158 *device_context = nullptr;
159 FunctionLibraryRuntime* flr = GetFLR(device_name);
160 if (flr == nullptr) {
161 return errors::InvalidArgument("Device name: ", device_name, " not found.");
162 }
163 Device* device = flr->device();
164 string device_type = device->parsed_name().type;
165 if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
166 // "TPU_SYSTEM" indicates that `device` is a CPU.
167 return Status::OK();
168 }
169 if (device_type == "GPU" || device_type == "TPU") {
170 auto* dev_info = flr->device()->tensorflow_gpu_device_info();
171 if (dev_info) {
172 *device_context = dev_info->default_context;
173 return Status::OK();
174 }
175 }
176 return errors::Internal("Device type: ", device_type,
177 " is currently unsupported for remote ",
178 "function executions");
179 }
180
GetFLR(const string & device_name) const181 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
182 const string& device_name) const {
183 Device* device = nullptr;
184 if (device_name != kDefaultFLRDevice) {
185 if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
186 VLOG(1) << "Could not find device: " << device_name;
187 return nullptr;
188 }
189 }
190 const auto& iter = flr_map_.find(device);
191 if (iter == flr_map_.end()) {
192 LOG(ERROR) << "Could not find device: " << device_name;
193 return nullptr;
194 }
195 return iter->second.get();
196 }
197
// Thread-safe wrapper: acquires mu_ and registers a single-device function
// (its canonical key, target device and device-local handle), returning the
// newly allocated process-wide handle.
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
    const string& function_key, const string& device_name,
    FunctionLibraryRuntime::LocalHandle local_handle) {
  mutex_lock l(mu_);
  return AddHandleLocked(function_key, device_name, local_handle);
}
204
AddHandleLocked(const string & function_key,const string & device_name,FunctionLibraryRuntime::LocalHandle local_handle)205 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
206 const string& function_key, const string& device_name,
207 FunctionLibraryRuntime::LocalHandle local_handle) {
208 auto h = next_handle_;
209 function_data_[h] =
210 MakeUnique<FunctionData>(device_name, local_handle, function_key);
211 table_[function_key] = h;
212 next_handle_++;
213 return h;
214 }
215
216 FunctionLibraryRuntime::Handle
AddMultiDeviceHandle(std::unique_ptr<MultiDeviceFunctionData> data,const string & function_key)217 ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
218 std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
219 mutex_lock l(mu_);
220 auto h = next_handle_;
221 mdevice_data_[h] = std::move(data);
222 table_[function_key] = h;
223 next_handle_++;
224 return h;
225 }
226
GetHandle(const string & function_key) const227 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
228 const string& function_key) const {
229 tf_shared_lock l(mu_);
230 return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
231 }
232
// True iff `handle` refers to a single-device function instantiated on
// `device_name` (multi-device handles always report false, since
// GetHandleOnDevice returns kInvalidLocalHandle for them).
bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle) const {
  return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
}
237
238 FunctionLibraryRuntime::LocalHandle
GetHandleOnDevice(const string & device_name,FunctionLibraryRuntime::Handle handle) const239 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
240 const string& device_name, FunctionLibraryRuntime::Handle handle) const {
241 tf_shared_lock l(mu_);
242
243 auto miter = mdevice_data_.find(handle);
244 if (miter != mdevice_data_.end()) {
245 return kInvalidLocalHandle;
246 }
247
248 auto iter = function_data_.find(handle);
249 if (iter == function_data_.end()) {
250 return kInvalidLocalHandle;
251 }
252 FunctionData* function_data = iter->second.get();
253 if (function_data->target_device() != device_name) {
254 return kInvalidLocalHandle;
255 }
256 return function_data->local_handle();
257 }
258
GetDeviceName(FunctionLibraryRuntime::Handle handle) const259 string ProcessFunctionLibraryRuntime::GetDeviceName(
260 FunctionLibraryRuntime::Handle handle) const {
261 tf_shared_lock l(mu_);
262 auto iter = function_data_.find(handle);
263 CHECK(iter != function_data_.end());
264 FunctionData* function_data = iter->second.get();
265 return function_data->target_device();
266 }
267
268 ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
IsMultiDevice(FunctionLibraryRuntime::Handle handle) const269 ProcessFunctionLibraryRuntime::IsMultiDevice(
270 FunctionLibraryRuntime::Handle handle) const {
271 tf_shared_lock l(mu_);
272 const auto& it = mdevice_data_.find(handle);
273 if (it != mdevice_data_.end()) {
274 return it->second.get();
275 }
276 return nullptr;
277 }
278
279 namespace {
280 // Sets `group` to the first colocation group specified in `node`. If no
281 // group is specified, does not touch `group`.
GetColocationGroup(const Node * node,string * group)282 void GetColocationGroup(const Node* node, string* group) {
283 // We hoist the conversion from C-style string literal to string here,
284 // so that we can avoid the many repeated calls to strlen().
285 static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
286 const AttrValue* attr_value =
287 node->attrs().Find(kColocationAttrNameStringPiece);
288 if (attr_value != nullptr && attr_value->has_list() &&
289 attr_value->list().s_size() > 0) {
290 *group = attr_value->list().s(0);
291 }
292 }
293
AssignedOrRequestedDeviceName(const Node & node)294 const string* AssignedOrRequestedDeviceName(const Node& node) {
295 if (node.has_assigned_device_name()) {
296 return &node.assigned_device_name();
297 }
298 return &node.requested_device();
299 }
300
301 } // anonymous namespace
302
// Assigns devices to the _Arg and _Retval nodes of `graph` prior to
// placement. Args are pinned to `input_devices` by their "index" attr.
// Retvals are pinned to `output_devices` when given; otherwise the device
// is inferred from the producing node (its assigned/requested device or a
// colocation constraint), which must resolve to exactly one device in
// `device_set`. Returns InvalidArgument when a device spec cannot be
// parsed or matched unambiguously.
Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
    const std::vector<string>& input_devices,
    const std::vector<string>& output_devices, const DeviceSet& device_set,
    Graph* graph) const {
  // If output_devices are not specified, we want to set the output device
  // based on the device of the output producing node. The output producing
  // node can be an arg node because functions can simply return their
  // arguments. To make sure that the output producing nodes have assigned
  // devices, we assign them to arguments first.
  for (Node* node : graph->op_nodes()) {
    if (node->IsArg()) {
      const AttrValue* attr_value;
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int64 index = attr_value->i();
      node->set_assigned_device_name(input_devices[index]);
    }
  }

  for (Node* node : graph->op_nodes()) {
    if (node->IsRetval()) {
      if (output_devices.empty()) {
        VLOG(3) << "Trying to determine device for node " << node->name();
        // If output_devices are empty, the node producing retval
        // must have explicitly assigned device or a colocation constraint
        // to a node with explicitly assigned device.
        for (const auto& it : node->in_edges()) {
          if (!it->IsControlEdge()) {
            Node* src_node = it->src();
            const string* src_device = AssignedOrRequestedDeviceName(*src_node);
            string colocation_group = "";
            GetColocationGroup(src_node, &colocation_group);
            VLOG(3) << "Considering src: " << src_node->name()
                    << " src_device: " << *src_device
                    << " colo group: " << colocation_group;
            // Walk back through chains of Identity nodes that carry no
            // device or colocation information of their own.
            while (src_device->empty() && colocation_group.empty() &&
                   src_node->IsIdentity()) {
              src_node = *src_node->in_nodes().begin();
              src_device = AssignedOrRequestedDeviceName(*src_node);
              GetColocationGroup(src_node, &colocation_group);
              VLOG(3) << "Considering src: " << src_node->name()
                      << " src_device: " << *src_device
                      << " colo group: " << colocation_group;
            }

            if (!colocation_group.empty()) {
              // Propagate the producer's colocation constraint onto the
              // retval so the Placer colocates them.
              AttrValue::ListValue colo_attr;
              colo_attr.add_s(colocation_group);
              std::vector<string> colo_slice = {colocation_group};
              node->AddAttr(kColocationAttrName, colo_slice);
            } else if (!src_device->empty()) {
              // src_device can be a partially specified device. Find the
              // matching device in the device_set.
              DeviceNameUtils::ParsedName parsed;
              if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
                return errors::InvalidArgument(
                    "Failed to parse explicit device specification ",
                    *src_device);
              }
              std::vector<Device*> matching_devices;
              device_set.FindMatchingDevices(parsed, &matching_devices);
              if (matching_devices.empty()) {
                return errors::InvalidArgument(
                    "Unable to find any devices for spec ", *src_device);
              } else if (matching_devices.size() != 1) {
                // Convert a vector of devices to a string.
                // Using absl::StrJoin did not work in Android builds.
                string devices = "[";
                for (Device* device : matching_devices) {
                  devices.append(device->name());
                  devices.append(", ");
                }
                if (devices.size() > 2) {
                  devices.resize(devices.size() - 2);
                }
                devices.append("]");

                return errors::InvalidArgument(
                    "When FunctionLibraryRuntime::Options.output_devices are "
                    "not specified for a multi-device function, the device "
                    "specification on the output node must match exactly one "
                    "device. Matched devices are ",
                    devices);
              }
              VLOG(3) << "Setting output device to "
                      << matching_devices[0]->name() << " for node "
                      << node->DebugString();
              node->set_assigned_device_name(matching_devices[0]->name());
            }
          }
        }
      } else {
        // Explicit output devices were provided: pin the retval by its
        // "index" attr.
        const AttrValue* attr_value;
        TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
        int64 index = attr_value->i();
        // output_devices size is checked in InstantiateMultiDevice
        DCHECK_GT(output_devices.size(), index);
        VLOG(3) << "Setting output device to " << output_devices[index]
                << " for return at index " << index;
        node->set_assigned_device_name(output_devices[index]);
      }
    }
  }
  return Status::OK();
}
407
408 namespace {
409
ValidateNoListArguments(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const char * arg_type,const string & function_name)410 Status ValidateNoListArguments(
411 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
412 const string& function_name) {
413 for (const OpDef::ArgDef& arg : args) {
414 if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
415 return errors::InvalidArgument(
416 "Function ", function_name, " has an ", arg_type, " named \"",
417 arg.name(),
418 "\" that is a list of tensors."
419 " Multi-device functions support only single-tensor inputs "
420 " and outputs");
421 }
422 }
423 return Status::OK();
424 }
425
// Validates that `fdef` can be instantiated as a multi-device function with
// the given `options`: no list-typed inputs/outputs, no kIntsOnDeviceAttr,
// input_devices matching the argument count, output_devices empty or
// matching the output count, and no unsupported options (state_handle,
// create_kernels_eagerly).
Status ValidateMultiDeviceOptions(
    const FunctionDef& fdef,
    const FunctionLibraryRuntime::InstantiateOptions& options) {
  const OpDef& signature = fdef.signature();
  // Multi-device functions don't currently support list inputs or outputs
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
                                             signature.name()));
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
                                             signature.name()));

  // Functions flagged to keep int32s on device are not supported here.
  if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
      fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
    return errors::Unimplemented(
        "Function '", signature.name(), "' has `",
        FunctionLibraryDefinition::kIntsOnDeviceAttr,
        "` attribute set. This attribute is not currently supported by "
        "multi-device functions.");
  }

  // Exactly one input device per argument is required.
  if (options.input_devices.size() != signature.input_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.input_devices must have the same length "
        "as the number of arguments: input_devices length = ",
        options.input_devices.size(),
        " number of arguments = ", signature.input_arg_size());
  }
  // Output devices may be omitted entirely (inferred later in
  // PinArgsAndRets) but if given must cover every output.
  if (!options.output_devices.empty() &&
      options.output_devices.size() != signature.output_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.output_devices must either be empty or have "
        "the same length as the number of arguments: output_devices length "
        "= ",
        options.output_devices.size(),
        " number of arguments = ", signature.output_arg_size());
  }

  if (!options.state_handle.empty()) {
    return errors::Unimplemented(
        "InstantiateOptions.state_handle is not supported for multi-device "
        "functions. Function: ",
        signature.name());
  }
  if (options.create_kernels_eagerly) {
    return errors::Unimplemented(
        "InstantiateOptions.create_kernels_eagerly is not supported for "
        "multi-device functions. Function: ",
        signature.name());
  }

  return Status::OK();
}
477
// Converts `fdef` (instantiated with `attrs` against `lib_def`) into a
// Graph, transferring the graph's ownership to `*graph`, and collects the
// names of its return nodes and control-return nodes.
Status GetGraphAndRets(const string& function_name, AttrSlice attrs,
                       const FunctionDef* fdef,
                       const FunctionLibraryDefinition* lib_def,
                       std::unique_ptr<Graph>* graph,
                       std::vector<string>* ret_node_names,
                       std::vector<string>* control_ret_node_names) {
  // Resolve op signatures against the (possibly overlay) library.
  auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
    return lib_def->LookUpOpDef(op, sig);
  };
  FunctionBody* tmp_fbody;
  // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
  TF_RETURN_IF_ERROR(
      FunctionDefToBodyHelper(*fdef, attrs, lib_def, get_func_sig, &tmp_fbody));
  if (tmp_fbody == nullptr) {
    LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
    return errors::Internal("Failed to construct FunctionBody for ",
                            function_name);
  }
  std::unique_ptr<FunctionBody> fbody(tmp_fbody);
  // Steal the graph from the FunctionBody; nulling fbody->graph prevents
  // the FunctionBody destructor from deleting it.
  *graph = std::unique_ptr<Graph>(fbody->graph);
  fbody->graph = nullptr;
  ret_node_names->reserve(fbody->ret_nodes.size());
  for (const Node* node : fbody->ret_nodes) {
    ret_node_names->push_back(node->name());
  }
  control_ret_node_names->reserve(fbody->control_ret_nodes.size());
  for (const Node* node : fbody->control_ret_nodes) {
    control_ret_node_names->push_back(node->name());
  }
  return Status::OK();
}
509
510 } // anonymous namespace
511
// Instantiates `function_name` as a multi-device function: converts the
// FunctionDef to a graph, pins args/retvals, runs placement and the
// registered optimization passes, partitions the placed graph by device,
// and instantiates each per-device component function on its FLR. The
// resulting process-wide handle is stored in `*handle`. Re-instantiating
// the same canonicalized function bumps a refcount and reuses the handle.
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::Handle* handle) {
  // Check if this function has already been instantiated.
  const string& function_key = Canonicalize(function_name, attrs, options);

  {
    mutex_lock l(mu_);
    const auto& it = table_.find(function_key);
    if (it != table_.end()) {
      *handle = it->second;
      ++mdevice_data_[*handle]->instantiation_counter_;
      return Status::OK();
    }
  }

  VLOG(1) << "Instantiating MultiDevice function \"" << function_name
          << "\" on default device \"" << options.target << "\"";
  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Requested input devices:";
    for (const string& device : options.input_devices) {
      VLOG(3) << " " << device;
    }
    VLOG(3) << "Requested output devices:";
    for (const string& device : options.output_devices) {
      VLOG(3) << " " << device;
    }
  }

  // Prefer the caller-supplied overlay library when looking up the function.
  const FunctionLibraryDefinition* lib_def =
      options.overlay_lib == nullptr ? lib_def_ : options.overlay_lib;

  const FunctionDef* fdef = lib_def->Find(function_name);
  if (fdef == nullptr) {
    return errors::InvalidArgument("Failed to find function \"", function_name,
                                   "\" in function library: ", lib_def);
  }

  TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));

  std::unique_ptr<Graph> graph;
  std::vector<string> ret_node_names;
  std::vector<string> control_ret_node_names;

  TF_RETURN_IF_ERROR(GetGraphAndRets(function_name, attrs, fdef, lib_def,
                                     &graph, &ret_node_names,
                                     &control_ret_node_names));

  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectRawGraph(def);
  }

  DeviceSet device_set;
  for (auto d : device_mgr_->ListDevices()) {
    device_set.AddDevice(d);
  }

  // Pin _Arg/_Retval nodes to devices before placement so the Placer
  // respects caller-specified input/output devices.
  TF_RETURN_IF_ERROR(PinArgsAndRets(
      options.input_devices, options.output_devices, device_set, graph.get()));

  std::unique_ptr<MultiDeviceFunctionData> data =
      MakeUnique<MultiDeviceFunctionData>(function_name, function_key,
                                          ret_node_names.size(),
                                          lib_def->ReachableDefinitions(*fdef));

  GraphOptimizationPassOptions optimization_options;
  // TODO(iga): Thread other relevant options from SessionOptions.
  SessionOptions session_options;
  session_options.env = env_;
  session_options.config = options.config_proto;
  optimization_options.session_options = &session_options;
  optimization_options.graph = &graph;
  optimization_options.flib_def = &data->overlay_lib_;
  optimization_options.device_set = &device_set;

  DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  DumpGraph("Before calling Placer", graph.get());
  // Make the FunctionLibraryRuntime's device the default device if
  // nothing else is hard coded. This allows the same function definition
  // to be specialized to different devices depending on the
  // PartitionedCallOp's device.
  Device* default_device = nullptr;
  if (!options.target.empty()) {
    FunctionLibraryRuntime* flr = GetFLR(options.target);
    if (flr == nullptr) {
      return errors::InvalidArgument(
          "Cannot instantiate multi-device function with target device ",
          options.target);
    }
    default_device = flr->device();
  }

  // TODO(b/124993244): Smartly merge options in nested defuns, and raise
  // exceptions/warnings in case where nested function call options are ignored.
  Placer placer(graph.get(), &device_set, default_device,
                options.config_proto.allow_soft_placement(),
                options.config_proto.log_device_placement());
  TF_RETURN_IF_ERROR(placer.Run());

  DumpGraph("Before running POST_PLACEMENT passes", graph.get());
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

  Device* cpu_device;
  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));

  if (options.optimize_graph_fn) {
    DumpGraph("Before running graph optimization fn", graph.get());
    // Optimization failures are deliberately non-fatal: fall back to the
    // unoptimized graph with a warning.
    Status status = options.optimize_graph_fn(
        std::move(ret_node_names), std::move(control_ret_node_names),
        &data->overlay_lib_, device_set, cpu_device, &graph);
    if (!status.ok()) {
      LOG(WARNING) << "Ignoring multi-device function optimization failure: "
                   << status.ToString();
    }
    DumpGraph("After optimization", graph.get());
  }

  DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
  DumpGraph("After all optimization passes", graph.get());

  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectOptimizedGraph(def);
  }

  // Split the placed graph into one subgraph per target device.
  std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(device_set, std::move(graph), &subgraphs));

  if (options.graph_collector != nullptr) {
    for (const auto& pair : subgraphs) {
      GraphDef def;
      pair.second->ToGraphDef(&def);
      *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
      options.graph_collector->CollectPartitionedGraph(def);
    }
  }

  // Turn each per-device subgraph into a uniquely named component
  // FunctionDef and instantiate it on that device's FLR.
  int i = 0;
  FunctionNameGenerator name_generator(&data->overlay_lib_, function_name);
  for (const auto& pair : subgraphs) {
    i += 1;
    // TODO(iga): Fail gracefully if the set of devices corresponds
    // to more than one address space.
    const string& target = pair.first;
    Graph* subgraph = pair.second.get();

    ComponentFunctionData* comp_data = &data->glue_[target];
    TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
        subgraph, &comp_data->arg_indices_, &comp_data->ret_indices_,
        &comp_data->arg_alloc_attrs_, &comp_data->ret_alloc_attrs_));
    FunctionDef shard;
    string unique_name = name_generator.GetName();
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph, unique_name, &shard));
    FunctionLibraryRuntime* target_flr = GetFLR(target);
    TF_RETURN_IF_ERROR(data->overlay_lib_.AddFunctionDef(shard));
    FunctionLibraryRuntime::InstantiateOptions opts;
    opts.executor_type = options.executor_type;
    opts.target = target;
    opts.overlay_lib = &data->overlay_lib_;
    FunctionLibraryRuntime::Handle component_handle;

    TF_RETURN_IF_ERROR(target_flr->Instantiate(
        unique_name, AttrSlice(&shard.attr()), opts, &component_handle));
    VLOG(1) << "Instantiated component function " << unique_name
            << " on device " << target << " with component handle "
            << component_handle;
    VLOG(2) << DebugString(shard);
    comp_data->handle_ = component_handle;
  }

  *handle = AddMultiDeviceHandle(std::move(data), function_key);
  VLOG(2) << "Instantiated MultiDevice function \"" << function_name
          << "\" with handle " << *handle;
  return Status::OK();
}
700
GetOutputDevices(FunctionLibraryRuntime::Handle handle,std::vector<Device * > * output_devices) const701 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
702 FunctionLibraryRuntime::Handle handle,
703 std::vector<Device*>* output_devices) const {
704 const MultiDeviceFunctionData* data = IsMultiDevice(handle);
705 if (data == nullptr) {
706 return errors::InvalidArgument(
707 "Failed for find multi-device function handle ", handle);
708 }
709
710 for (const auto& pair : data->glue_) {
711 const ComponentFunctionData& comp_data = pair.second;
712 DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size());
713
714 const string& target = pair.first;
715 FunctionLibraryRuntime* target_flr = GetFLR(target);
716 Device* target_device = target_flr->device();
717 const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_);
718 DCHECK(fbody != nullptr);
719
720 output_devices->resize(data->num_outputs_);
721 for (int j = 0; j < comp_data.ret_indices_.size(); ++j) {
722 int ret_index = comp_data.ret_indices_[j];
723 if (fbody->ret_types[j] == DT_RESOURCE) {
724 (*output_devices)[ret_index] = target_device;
725 } else {
726 (*output_devices)[ret_index] =
727 comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device;
728 }
729 }
730 }
731
732 return Status::OK();
733 }
734
RunMultiDevice(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done) const735 void ProcessFunctionLibraryRuntime::RunMultiDevice(
736 const FunctionLibraryRuntime::Options& opts,
737 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
738 std::vector<Tensor>* rets,
739 FunctionLibraryRuntime::DoneCallback done) const {
740 if (opts.create_rendezvous) {
741 // FLR->Run() is the default entry point. It checks for cancellation,
742 // creates rendezvous, etc.
743 // Letting create_rendezvous through will do the wrong thing - each
744 // component function will get a separate rendezvous created by its FLR.
745 done(
746 errors::Internal("Cannot call ProcessFunctionLibraryRuntime::Run with "
747 "create_rendezvous=true. Please run the function "
748 "using FunctionLibraryRuntime::Run"));
749 return;
750 }
751
752 const MultiDeviceFunctionData* data = IsMultiDevice(handle);
753 if (data == nullptr) {
754 done(
755 errors::InvalidArgument("Failed for find multi-device function handle ",
756 handle, ". Was the function instantiated?"));
757 return;
758 }
759
760 if (data->glue_.empty()) {
761 // Trivial case where the function body is empty.
762 done(Status::OK());
763 return;
764 }
765
766 auto* refcounted_done = new ReffedStatusCallback(std::move(done));
767 for (int i = 0; i < data->glue_.size(); ++i) {
768 refcounted_done->Ref();
769 }
770
771 FunctionLibraryRuntime::Options opts_copy = opts;
772 for (const auto& pair : data->glue_) {
773 const string& target = pair.first;
774 const ComponentFunctionData& comp_data = pair.second;
775 FunctionLibraryRuntime::Handle handle = pair.second.handle_;
776 VLOG(1) << "Running function shard on device " << target << " with handle "
777 << handle;
778
779 opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
780 opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
781 opts_copy.remote_execution = false;
782 std::vector<Tensor> comp_args =
783 GetArgsForIndices(comp_data.arg_indices_, args);
784 std::vector<Tensor>* comp_rets = new std::vector<Tensor>;
785 rets->resize(data->num_outputs_);
786 GetFLR(target)->Run(
787 opts_copy, handle, comp_args, comp_rets,
788 [comp_rets, rets, comp_data, refcounted_done](const Status& status) {
789 if (!status.ok()) {
790 LOG(ERROR) << "Component function execution failed: " << status;
791 refcounted_done->UpdateStatus(status);
792 } else {
793 for (int i = 0; i < comp_rets->size(); ++i) {
794 (*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
795 }
796 }
797 delete comp_rets;
798 // refcounted_done is thread-safe
799 refcounted_done->Unref();
800 });
801 }
802 refcounted_done->Unref();
803 }
804
Instantiate(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle)805 Status ProcessFunctionLibraryRuntime::Instantiate(
806 const string& function_name, AttrSlice attrs,
807 const FunctionLibraryRuntime::InstantiateOptions& options,
808 FunctionLibraryRuntime::Handle* handle) {
809 if (options.is_multi_device_function) {
810 return InstantiateMultiDevice(function_name, attrs, options, handle);
811 }
812
813 *handle = kInvalidHandle;
814 FunctionLibraryRuntime* flr = GetFLR(options.target);
815 if (flr != nullptr) {
816 return flr->Instantiate(function_name, attrs, options, handle);
817 }
818 if (parent_ == nullptr) {
819 return errors::Internal(
820 "Currently don't support instantiating functions on device: ",
821 options.target);
822 }
823 VLOG(1) << "ProcessFLR Instantiate: " << function_name
824 << " on: " << options.target;
825 string function_key = Canonicalize(function_name, attrs, options);
826 FunctionData* f;
827 {
828 mutex_lock l(mu_);
829 FunctionLibraryRuntime::Handle h =
830 gtl::FindWithDefault(table_, function_key, kInvalidHandle);
831 if (h == kInvalidHandle || function_data_.count(h) == 0) {
832 h = AddHandleLocked(function_key, options.target, kInvalidHandle);
833 }
834 f = function_data_[h].get();
835 *handle = h;
836 }
837 TF_RETURN_IF_ERROR(
838 f->DistributedInit(parent_, function_name, *lib_def_, attrs, options));
839 VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
840 << " on: " << options.target << " with handle: " << *handle
841 << " (this: " << this << ")";
842 return Status::OK();
843 }
844
RemoveHandle(FunctionLibraryRuntime::Handle handle)845 Status ProcessFunctionLibraryRuntime::RemoveHandle(
846 FunctionLibraryRuntime::Handle handle) {
847 mutex_lock l(mu_);
848 table_.erase(function_data_[handle]->function_key());
849 function_data_.erase(handle);
850 return Status::OK();
851 }
852
ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle)853 Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
854 FunctionLibraryRuntime::Handle handle) {
855 std::unique_ptr<MultiDeviceFunctionData> mdata;
856 {
857 mutex_lock l(mu_);
858 auto it = mdevice_data_.find(handle);
859 --it->second->instantiation_counter_;
860 if (it->second->instantiation_counter_ != 0) {
861 return Status::OK();
862 }
863 mdata = std::move(it->second);
864 table_.erase(mdata->function_key_);
865 mdevice_data_.erase(it);
866 }
867
868 // If we are here we are releasing the last instantiation of `handle`.
869 // Release all component function handles.
870 Status overall_status;
871 for (const auto& it : mdata->glue_) {
872 const string& device = it.first;
873 FunctionLibraryRuntime::Handle flr_handle = it.second.handle_;
874 FunctionLibraryRuntime* flr = GetFLR(device);
875 if (flr == nullptr) {
876 return errors::InvalidArgument(
877 "Failed to find FunctionLibraryRuntime for device ", device,
878 " when releasing multi-device function handle ", handle);
879 }
880 Status status = flr->ReleaseHandle(flr_handle);
881 if (!status.ok()) {
882 overall_status = status;
883 }
884 }
885
886 return overall_status;
887 }
888
ReleaseHandle(FunctionLibraryRuntime::Handle handle)889 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
890 FunctionLibraryRuntime::Handle handle) {
891 if (IsMultiDevice(handle)) {
892 return ReleaseMultiDeviceHandle(handle);
893 }
894
895 FunctionLibraryRuntime* flr = nullptr;
896 string target_device;
897 {
898 mutex_lock l(mu_);
899 CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
900 target_device = function_data_[handle]->target_device();
901 }
902 flr = GetFLR(target_device);
903 if (flr != nullptr) {
904 return flr->ReleaseHandle(handle);
905 }
906 return errors::InvalidArgument("Handle not found: ", handle);
907 }
908
// Runs the function identified by `handle`. Multi-device handles are
// dispatched to RunMultiDevice. Otherwise this path only supports remote
// execution: args are sent to the target device over `opts.rendezvous`, the
// target FLR runs the function, and the returns are received back
// asynchronously before `done` is invoked.
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  bool multi_device;
  {
    tf_shared_lock l(mu_);
    multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
  }
  if (multi_device) {
    return RunMultiDevice(opts, handle, args, rets, done);
  }

  FunctionLibraryRuntime* flr = nullptr;
  string target_device;
  FunctionLibraryRuntime::LocalHandle local_handle;
  {
    // Copy what we need out of function_data_ under the shared lock so no
    // lock is held across the Run/send/receive calls below.
    tf_shared_lock l(mu_);
    auto iter = function_data_.find(handle);
    if (iter == function_data_.end()) {
      done(errors::NotFound("Handle: ", handle, " not found."));
      return;
    }
    FunctionData* function_data = iter->second.get();
    target_device = function_data->target_device();
    local_handle = function_data->local_handle();
  }

  if (!opts.remote_execution) {
    done(
        errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
                                "only be called for multi-device functions or "
                                "for remote execution."));
    return;
  }

  flr = GetFLR(target_device);
  if (flr != nullptr) {
    auto rendezvous = opts.rendezvous;
    string source_device = opts.source_device;
    DeviceContext* device_context;
    Status s = GetDeviceContext(source_device, &device_context);
    if (!s.ok()) {
      done(s);
      return;
    }
    // Incarnations detect device restarts between send and receive.
    int64 src_incarnation, target_incarnation;
    s = GetDeviceIncarnation(source_device, &src_incarnation);
    s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
    if (!s.ok()) {
      done(s);
      return;
    }

    // Send the args over to the target device.
    s = SendTensors(source_device, target_device, "arg_", src_incarnation, args,
                    device_context, opts.args_alloc_attrs, rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    const std::vector<AllocatorAttributes>& rets_alloc_attrs =
        opts.rets_alloc_attrs;
    // Heap-allocated because the continuation below outlives this frame; it
    // is deleted on both the error and success paths of the callback.
    std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
    // std::bind moves `done` into the bound callable; the lambda takes it as
    // a reference parameter so it can be moved again into ReceiveTensorsAsync.
    flr->Run(opts, handle, args, remote_rets,
             std::bind(
                 [source_device, target_device, target_incarnation, rendezvous,
                  device_context, rets_alloc_attrs, remote_rets,
                  rets](const Status& status,
                        FunctionLibraryRuntime::DoneCallback& done) {
                   if (!status.ok()) {
                     delete remote_rets;
                     done(status);
                     return;
                   }
                   // Only the count is needed from remote_rets; free it before
                   // the async receive so it cannot leak if the receive hangs.
                   int64 num_returns = remote_rets->size();
                   delete remote_rets;
                   // Now receive the return values from the target.
                   ReceiveTensorsAsync(target_device, source_device, "ret_",
                                       target_incarnation, num_returns,
                                       device_context, rets_alloc_attrs,
                                       rendezvous, rets, std::move(done));
                 },
                 std::placeholders::_1, std::move(done)));
    return;
  }
  if (parent_ != nullptr) {
    // Device not managed locally: forward to the parent (cluster) FLR.
    parent_->Run(opts, local_handle, args, rets, std::move(done));
    return;
  }
  done(errors::Internal("Could not find device"));
}
1002
Clone(Env * env,int graph_def_version,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr) const1003 Status ProcessFunctionLibraryRuntime::Clone(
1004 Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
1005 CustomKernelCreator custom_kernel_creator,
1006 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1007 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) const {
1008 out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
1009 out_pflr->reset(new ProcessFunctionLibraryRuntime(
1010 device_mgr_, env, graph_def_version, out_lib_def->get(),
1011 optimizer_options, std::move(custom_kernel_creator), default_thread_pool_,
1012 parent_));
1013 return Status::OK();
1014 }
1015
1016 } // namespace tensorflow
1017