1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16
17 #include <iterator>
18 #include <utility>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/common_runtime/device_set.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/optimization_registry.h"
28 #include "tensorflow/core/common_runtime/partitioning_utils.h"
29 #include "tensorflow/core/common_runtime/placer.h"
30 #include "tensorflow/core/common_runtime/process_util.h"
31 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
32 #include "tensorflow/core/common_runtime/rendezvous_util.h"
33 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/graph_to_functiondef.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/types.h"
40 #include "tensorflow/core/framework/types.pb.h"
41 #include "tensorflow/core/graph/graph.h"
42 #include "tensorflow/core/graph/graph_node_util.h"
43 #include "tensorflow/core/graph/graph_partition.h"
44 #include "tensorflow/core/lib/core/blocking_counter.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/random/random.h"
50 #include "tensorflow/core/platform/notification.h"
51 #include "tensorflow/core/util/device_name_utils.h"
52 #include "tensorflow/core/util/dump_graph.h"
53 #include "tensorflow/core/util/ptr_util.h"
54 #include "tensorflow/core/util/reffed_status_callback.h"
55 #if !defined(IS_MOBILE_PLATFORM)
56 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
57 #endif // IS_MOBILE_PLATFORM
58
59 namespace tensorflow {
60
61 const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
62
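// Instantiates this function on a remote target via `parent` and marks it as
// cross-process. Only the first caller kicks off the instantiation; later
// callers wait on `init_done_` and are invoked with the recorded result.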
63 void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
64 DistributedFunctionLibraryRuntime* parent, const string& function_name,
65 const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
66 const FunctionLibraryRuntime::InstantiateOptions& options,
67 FunctionLibraryRuntime::DoneCallback done) {
68 {
69 mutex_lock l(mu_);
70 is_cross_process_ = true;
71 if (init_started_) {
72 init_done_.WaitForNotification();
73 done(init_result_);
74 return;
75 }
76 init_started_ = true;
77 }
78 parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
79 [this, done](const Status& s) {
80 init_done_.Notify();
81 done(s);
82 });
83 }
84
85 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
86 const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
87 int graph_def_version, const FunctionLibraryDefinition* lib_def,
88 const OptimizerOptions& optimizer_options,
89 thread::ThreadPool* default_thread_pool,
90 DistributedFunctionLibraryRuntime* parent,
91 const SessionMetadata* session_metadata,
92 Rendezvous::Factory rendezvous_factory)
93 : parent_(parent),
94 env_(env),
95 config_(config ? absl::make_optional(*config) : absl::nullopt),
96 device_mgr_(device_mgr),
97 lib_def_(lib_def),
98 default_thread_pool_(default_thread_pool),
99 flr_map_(new std::unordered_map<Device*,
100 std::unique_ptr<FunctionLibraryRuntime>>),
101 next_handle_(0),
102 session_metadata_(session_metadata),
103 rendezvous_factory_(std::move(rendezvous_factory)),
104 optimizer_options_(optimizer_options),
105 graph_def_version_(graph_def_version) {
106 if (device_mgr == nullptr) {
107 (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
108 nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
109 graph_def_version, lib_def_, default_thread_pool, optimizer_options,
110 session_metadata_, this);
111 return;
112 }
113 InitializeDeviceAndFlr();
114 }
115
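// Sends `tensors_to_send` from `source_device` to `target_device` through
// `rendezvous`, using keys of the form "<key_prefix><i>" for the i-th tensor.
// The receiving side is expected to call ReceiveTensorsAsync with the same
// key_prefix and source incarnation.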
116 /* static */
117 Status ProcessFunctionLibraryRuntime::SendTensors(
118 const string& source_device, const string& target_device,
119 const string& key_prefix, int64 src_incarnation,
120 gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
121 const std::vector<AllocatorAttributes>& alloc_attrs,
122 RendezvousInterface* rendezvous) {
123 std::vector<string> keys;
124 for (int i = 0; i < tensors_to_send.size(); ++i) {
125 string name = strings::StrCat(key_prefix, i);
126 string key = Rendezvous::CreateKey(source_device, src_incarnation,
127 target_device, name, FrameAndIter(0, 0));
128 keys.push_back(key);
129 }
130 TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
131 rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
132 return Status::OK();
133 }
134
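// Receives `num_tensors` tensors into `received_tensors` using the same
// "<key_prefix><i>" rendezvous keys produced by SendTensors, and invokes
// `done` with the final status.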
135 /* static */
136 void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
137 const string& source_device, const string& target_device,
138 const string& key_prefix, int64 src_incarnation, int64 num_tensors,
139 DeviceContext* device_context,
140 const std::vector<AllocatorAttributes>& alloc_attrs,
141 RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
142 StatusCallback done) {
143 std::vector<string> keys;
144 for (int64 i = 0; i < num_tensors; ++i) {
145 string name = strings::StrCat(key_prefix, i);
146 string key = Rendezvous::CreateKey(source_device, src_incarnation,
147 target_device, name, FrameAndIter(0, 0));
148 keys.push_back(key);
149 }
150 RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
151 received_tensors, std::move(done));
152 }
153
154 Status ProcessFunctionLibraryRuntime::GetRetTypes(
155 FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) {
156 FunctionLibraryRuntime* flr = nullptr;
157 {
158 tf_shared_lock l(mu_);
159 auto miter = mdevice_data_.find(h);
160 if (miter != mdevice_data_.end()) {
161 *ret_types = miter->second->ret_types_;
162 return Status::OK();
163 }
164 auto fiter = function_data_.find(h);
165 if (fiter != function_data_.end()) {
166 flr = GetFLR(fiter->second->target_device());
167 }
168 }
169 if (flr != nullptr) {
170 return flr->GetRetTypes(h, ret_types);
171 }
172 return errors::InvalidArgument("Handle ", h, " not found.");
173 }
174
175 Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
176 const string& device_name, int64* incarnation) const {
177 FunctionLibraryRuntime* flr = GetFLR(device_name);
178 if (flr == nullptr) {
179 return errors::InvalidArgument("Device name: ", device_name, " not found.");
180 }
181 *incarnation = flr->device()->attributes().incarnation();
182 return Status::OK();
183 }
184
185 Status ProcessFunctionLibraryRuntime::GetDeviceContext(
186 const string& device_name, DeviceContext** device_context) const {
187 *device_context = nullptr;
188 FunctionLibraryRuntime* flr = GetFLR(device_name);
189 if (flr == nullptr) {
190 return errors::InvalidArgument("Device name: ", device_name, " not found.");
191 }
192 Device* device = flr->device();
193 string device_type = device->parsed_name().type;
194 if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
195 // "TPU_SYSTEM" indicates that `device` is a CPU.
196 return Status::OK();
197 }
198
199 if (device->IsRemoteCallAllowed()) {
200 auto* dev_info = flr->device()->tensorflow_gpu_device_info();
201 if (dev_info) {
202 *device_context = dev_info->default_context;
203 return Status::OK();
204 }
205 }
206
207 return errors::Internal("Device type: ", device_type,
208 " is currently unsupported for remote ",
209 "function executions");
210 }
211
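// Rebuilds `device_set_` from the remote device manager when one is
// available (falling back to the local device manager), and creates a
// FunctionLibraryRuntime for every local device that does not have one yet.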
212 void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
213 DeviceMgr const* all_devices = device_mgr_;
214 if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
215 all_devices = parent_->remote_device_mgr();
216 }
217
218 mutex_lock l(mu_);
219 device_set_ = std::make_shared<DeviceSet>();
220 for (auto d : all_devices->ListDevices()) {
221 device_set_->AddDevice(d);
222 }
223 for (Device* d : device_mgr_->ListDevices()) {
224 if ((*flr_map_)[d] == nullptr) {
225 (*flr_map_)[d] = NewFunctionLibraryRuntime(
226 device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
227 graph_def_version_, lib_def_, default_thread_pool_,
228 optimizer_options_, session_metadata_, this);
229 }
230 }
231 }
232
233 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
234 const string& device_name) const {
235 Device* device = nullptr;
236 if (device_name != kDefaultFLRDevice) {
237 if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
238 VLOG(4) << "Could not find device: " << device_name;
239 return nullptr;
240 }
241 }
242 const auto& iter = flr_map_->find(device);
243 if (iter == flr_map_->end()) {
244 VLOG(1) << "Could not find device: " << device_name
245 << "in the local process.";
246 return nullptr;
247 }
248 return iter->second.get();
249 }
250
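// Records a (function_key, target device, local handle) entry and returns a
// newly allocated process-wide handle for it.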
251 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
252 const string& function_key, const string& device_name,
253 FunctionLibraryRuntime::LocalHandle local_handle) {
254 mutex_lock l(mu_);
255 return AddHandleLocked(function_key, device_name, local_handle);
256 }
257
258 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
259 const string& function_key, const string& device_name,
260 FunctionLibraryRuntime::LocalHandle local_handle) {
261 auto h = next_handle_;
262 function_data_[h] =
263 absl::make_unique<FunctionData>(device_name, local_handle, function_key);
264 table_[function_key] = h;
265 next_handle_++;
266 return h;
267 }
268
269 FunctionLibraryRuntime::Handle
270 ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
271 std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
272 mutex_lock l(mu_);
273 auto h = next_handle_;
274 mdevice_data_[h] = std::move(data);
275 table_[function_key] = h;
276 next_handle_++;
277 return h;
278 }
279
280 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
281 const string& function_key) const {
282 tf_shared_lock l(mu_);
283 return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
284 }
285
286 bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
287 const string& device_name, FunctionLibraryRuntime::Handle handle) const {
288 return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
289 }
290
291 FunctionLibraryRuntime::LocalHandle
292 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
293 const string& device_name, FunctionLibraryRuntime::Handle handle,
294 bool include_multi_device) const {
295 tf_shared_lock l(mu_);
296
297 auto miter = mdevice_data_.find(handle);
298 if (miter != mdevice_data_.end()) {
299 if (!include_multi_device) return kInvalidLocalHandle;
300
301 const MultiDeviceFunctionData& data = *miter->second;
302 if (data.glue_.size() != 1) return kInvalidLocalHandle;
303
304 const auto& pair = *data.glue_.begin();
305 const string& func_device_name = pair.first;
306 const ComponentFunctionData& component_data = pair.second;
307 if (func_device_name != device_name) return kInvalidLocalHandle;
308
309 // Replace the given handle with the handle for the single component
310 // function.
311 handle = component_data.handle;
312 }
313
314 auto iter = function_data_.find(handle);
315 if (iter == function_data_.end()) {
316 return kInvalidLocalHandle;
317 }
318 FunctionData* function_data = iter->second.get();
319 if (function_data->target_device() != device_name) {
320 return kInvalidLocalHandle;
321 }
322 return function_data->local_handle();
323 }
324
325 string ProcessFunctionLibraryRuntime::GetDeviceName(
326 FunctionLibraryRuntime::Handle handle) const {
327 tf_shared_lock l(mu_);
328 auto iter = function_data_.find(handle);
329 CHECK(iter != function_data_.end());
330 FunctionData* function_data = iter->second.get();
331 return function_data->target_device();
332 }
333
334 ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
335 ProcessFunctionLibraryRuntime::IsMultiDevice(
336 FunctionLibraryRuntime::Handle handle) const {
337 tf_shared_lock l(mu_);
338 const auto& it = mdevice_data_.find(handle);
339 if (it != mdevice_data_.end()) {
340 return it->second.get();
341 }
342 return nullptr;
343 }
344
345 namespace {
346 // Sets `group` to the first colocation group specified in `node`. If no
347 // group is specified, does not touch `group`.
348 void GetColocationGroup(const Node* node, string* group) {
349 // We hoist the conversion from C-style string literal to string here,
350 // so that we can avoid the many repeated calls to strlen().
351 static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
352 const AttrValue* attr_value =
353 node->attrs().Find(kColocationAttrNameStringPiece);
354 if (attr_value != nullptr && attr_value->has_list() &&
355 attr_value->list().s_size() > 0) {
356 *group = attr_value->list().s(0);
357 }
358 }
359
360 const string* AssignedOrRequestedDeviceName(const Node& node) {
361 if (node.has_assigned_device_name()) {
362 return &node.assigned_device_name();
363 }
364 return &node.requested_device();
365 }
366
367 Status SetArgShape(
368 const std::unordered_map<int, DtypeAndPartialTensorShape>&
369 input_resource_dtypes_and_shapes,
370 const std::vector<Node*>& arg_nodes) {
371 for (Node* n : arg_nodes) {
372 int index;
373 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
374 DataType dtype;
375 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
376 if (dtype == DT_RESOURCE) {
377 auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
378 if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
379 AttrValue dtype_attr_value;
380 dtype_attr_value.mutable_list()->add_type(
381 dtype_and_shape_iter->second.dtype);
382 n->AddAttr("_handle_dtypes", dtype_attr_value);
383 TensorShapeProto shape_proto;
384 dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
385 AttrValue shape_attr_value;
386 *shape_attr_value.mutable_list()->add_shape() = shape_proto;
387 n->AddAttr("_handle_shapes", shape_attr_value);
388 }
389 }
390 }
391 return Status::OK();
392 }
393
394 // Returns the local tensors referred to by `args`.
395 std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
396 std::vector<Tensor> tensors;
397 for (const auto& arg : args) {
398 if (arg.index() == 0) {
399 tensors.push_back(absl::get<Tensor>(arg));
400 }
401 }
402 return tensors;
403 }
404
405 // Update the done callback to push Tensors in `tensors` into `rets`.
406 FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
407 std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
408 FunctionLibraryRuntime::DoneCallback done) {
409 return [rets, tensors, done = std::move(done)](const Status& s) {
410 if (s.ok()) {
411 for (const auto& t : *tensors) {
412 rets->push_back(t);
413 }
414 }
415 delete tensors;
416 done(s);
417 };
418 }
419
420 } // anonymous namespace
421
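// Pins arg and ret nodes to devices. Each _Arg node is assigned the
// corresponding entry of `input_devices`. Each _Retval node is either
// assigned the corresponding entry of `output_devices`, or, when
// `output_devices` is empty, inherits a device (or colocation constraint)
// from the node that produces its value.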
422 Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
423 const std::vector<string>& input_devices,
424 const std::vector<string>& output_devices, const DeviceSet& device_set,
425 const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
426 Device* default_device) const {
427 // If output_devices are not specified, we want to set the output device
428 // based on the device of the output producing node. The output producing
429 // node can be an arg node because functions can simply return their
430 // arguments. To make sure that the output producing nodes have assigned
431 // devices, we assign them to arguments first.
432 for (Node* node : arg_nodes) {
433 const AttrValue* attr_value;
434 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
435 int64 index = attr_value->i();
436 node->set_assigned_device_name(input_devices[index]);
437 }
438
439 for (Node* node : ret_nodes) {
440 if (output_devices.empty()) {
441 DataType dtype;
442 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
443
444 VLOG(3) << "Trying to determine device for node " << node->name()
445 << "[T=" << DataTypeString(dtype) << "]";
446
447 // If output_devices are empty, the node producing the retval
448 // must have an explicitly assigned device or a colocation constraint
449 // to a node with an explicitly assigned device.
450 for (const auto& it : node->in_edges()) {
451 if (it->IsControlEdge()) continue;
452
453 Node* src_node = it->src();
454 const string* src_device = AssignedOrRequestedDeviceName(*src_node);
455 string colocation_group = "";
456 GetColocationGroup(src_node, &colocation_group);
457 VLOG(3) << "Considering src: " << src_node->name()
458 << " src_device: " << *src_device
459 << " colo group: " << colocation_group;
460 while (src_device->empty() && colocation_group.empty() &&
461 src_node->IsIdentity()) {
462 // Only follows the real data input of Identity, not control edges.
463 Node* input_node;
464 TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
465 src_node = input_node;
466
467 src_device = AssignedOrRequestedDeviceName(*src_node);
468 GetColocationGroup(src_node, &colocation_group);
469 VLOG(3) << "Considering src: " << src_node->name()
470 << " src_device: " << *src_device
471 << " colo group: " << colocation_group;
472 }
473
474 // If the resource is produced by a function call node, we can't trust
475 // the source node's device assignment, because multi-device functions
476 // can return resources placed on multiple devices. In such cases we
477 // leave the retval device assignment empty, and rely on the placer to
478 // infer the correct assignment based on the actual output device.
479 const bool can_use_src_node_device =
480 !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def_, *src_node));
481
482 if (!colocation_group.empty()) {
483 AttrValue::ListValue colo_attr;
484 colo_attr.add_s(colocation_group);
485 std::vector<string> colo_slice = {colocation_group};
486 node->AddAttr(kColocationAttrName, colo_slice);
487 } else if (!src_device->empty() && can_use_src_node_device) {
488 // src_device can be a partially specified device. Find the
489 // matching device in the device_set.
490 DeviceNameUtils::ParsedName parsed;
491 if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
492 return errors::InvalidArgument(
493 "Failed to parse explicit device specification ", *src_device);
494 }
495 std::vector<Device*> matching_devices;
496 device_set.FindMatchingDevices(parsed, &matching_devices);
497 if (matching_devices.empty()) {
498 if (default_device != nullptr) {
499 matching_devices.push_back(default_device);
500 } else {
501 return errors::InvalidArgument(
502 "Unable to find any devices for spec ", *src_device);
503 }
504 } else if (matching_devices.size() != 1) {
505 bool on_same_task = true;
506 for (int i = 1; i < matching_devices.size(); ++i) {
507 if (!DeviceNameUtils::IsSameAddressSpace(
508 matching_devices.at(0)->parsed_name(),
509 matching_devices.at(i)->parsed_name())) {
510 on_same_task = false;
511 break;
512 }
513 }
514 // If the src node of an output is assigned to an address space (e.g.
515 // py_func), rely on the placer to assign a device to the output.
516 if (on_same_task) {
517 continue;
518 }
519 // Compare with default_device if it has a narrower scope matching
520 // requested device.
521 int colocated_on_default_device = 0;
522 for (int i = 0; i < matching_devices.size(); ++i) {
523 if (DeviceNameUtils::IsSameAddressSpace(
524 default_device->parsed_name(),
525 matching_devices.at(i)->parsed_name())) {
526 colocated_on_default_device++;
527 }
528 }
529 // If exactly one matching device is colocated with the default device,
530 // rely on the placer; otherwise fall through and raise an error.
531 if (colocated_on_default_device == 1) {
532 continue;
533 }
534
535 // Convert a vector of devices to a string.
536 // Using absl::StrJoin did not work in Android builds.
537 string devices = "[";
538 for (Device* device : matching_devices) {
539 devices.append(device->name());
540 devices.append(", ");
541 }
542 if (devices.size() > 2) {
543 devices.resize(devices.size() - 2);
544 }
545 devices.append("]");
546
547 return errors::InvalidArgument(
548 *src_device,
549 "When FunctionLibraryRuntime::Options.output_devices are "
550 "not specified for a multi-device function, the device "
551 "specification on the output node must match exactly one "
552 "device. Matched devices are ",
553 devices);
554 }
555 VLOG(3) << "Setting output device to " << matching_devices[0]->name()
556 << " for node " << SummarizeNode(*node);
557 node->set_assigned_device_name(matching_devices[0]->name());
558 } else if (!src_device->empty() && !can_use_src_node_device) {
559 VLOG(3) << "Did not set device for a resource output node "
560 << SummarizeNode(*node);
561 }
562 }
563 } else {
564 const AttrValue* attr_value;
565 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
566 int64 index = attr_value->i();
567 // output_devices size is checked in InstantiateMultiDevice
568 DCHECK_GT(output_devices.size(), index);
569 VLOG(3) << "Setting output device to " << output_devices[index]
570 << " for return at index " << index;
571 node->set_assigned_device_name(output_devices[index]);
572 }
573 }
574 return Status::OK();
575 }
576
577 namespace {
578
579 Status ValidateNoListArguments(
580 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
581 const string& function_name) {
582 for (const OpDef::ArgDef& arg : args) {
583 if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
584 return errors::InvalidArgument(
585 "Function ", function_name, " has an ", arg_type, " named \"",
586 arg.name(),
587 "\" that is a list of tensors."
588 " Multi-device functions support only single-tensor inputs "
589 " and outputs");
590 }
591 }
592 return Status::OK();
593 }
594
595 Status ValidateMultiDeviceOptions(
596 const FunctionDef& fdef,
597 const FunctionLibraryRuntime::InstantiateOptions& options) {
598 const OpDef& signature = fdef.signature();
599 // Multi-device functions currently do not support list inputs or outputs.
600 TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
601 signature.name()));
602 TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
603 signature.name()));
604 if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
605 fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
606 return errors::Unimplemented(
607 "Function '", signature.name(), "' has `",
608 FunctionLibraryDefinition::kIntsOnDeviceAttr,
609 "` attribute set. This attribute is not currently supported by "
610 "multi-device functions.");
611 }
612 if (options.input_devices.size() != signature.input_arg_size()) {
613 return errors::InvalidArgument(
614 "InstantiateOptions.input_devices must have the same length "
615 "as the number of arguments: input_devices length = ",
616 options.input_devices.size(),
617 " number of arguments = ", signature.input_arg_size());
618 }
619 if (!options.output_devices.empty() &&
620 options.output_devices.size() != signature.output_arg_size()) {
621 return errors::InvalidArgument(
622 "InstantiateOptions.output_devices must either be empty or have the "
623 "same length as the number of arguments: output_devices length = ",
624 options.output_devices.size(),
625 " number of arguments = ", signature.output_arg_size());
626 }
627 return Status::OK();
628 }
629
630 } // anonymous namespace
631
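// Converts `fdef` into a Graph (via FunctionDefToBodyHelper) and extracts the
// arg/ret nodes, the ret node names and types, and the control-ret node
// names. Ownership of the constructed graph is transferred to `graph`.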
632 Status GetGraphAndArgRets(
633 const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
634 const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
635 std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
636 std::vector<string>* ret_node_names, DataTypeVector* ret_types,
637 std::vector<string>* control_ret_node_names) {
638 std::unique_ptr<FunctionBody> fbody;
639 // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
640 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
641 if (!fbody) {
642 LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
643 return errors::Internal("Failed to construct FunctionBody for ",
644 function_name);
645 }
646 *graph = std::unique_ptr<Graph>(fbody->graph);
647 arg_nodes->reserve(fbody->arg_nodes.size());
648 std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(),
649 std::back_inserter(*arg_nodes));
650 ret_nodes->reserve(fbody->ret_nodes.size());
651 std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(),
652 std::back_inserter(*ret_nodes));
653 fbody->graph = nullptr;
654 ret_node_names->reserve(fbody->ret_nodes.size());
655 for (const Node* node : fbody->ret_nodes) {
656 ret_node_names->push_back(node->name());
657 }
658 for (const auto& ret_type : fbody->ret_types) {
659 ret_types->push_back(ret_type);
660 }
661 control_ret_node_names->reserve(fbody->control_ret_nodes.size());
662 for (const Node* node : fbody->control_ret_nodes) {
663 control_ret_node_names->push_back(node->name());
664 }
665 return Status::OK();
666 }
667
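// Instantiates a multi-device function. Reuses a cached instantiation if the
// canonicalized function key is already known; otherwise builds the function
// graph, pins args/rets, runs placement and the registered optimization
// passes, partitions the graph per device, and instantiates one component
// function per partition (locally when an FLR exists for the target device,
// remotely otherwise).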
668 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
669 const string& function_name, AttrSlice attrs,
670 const FunctionLibraryRuntime::InstantiateOptions& options,
671 FunctionLibraryRuntime::Handle* handle) {
672 // Check if this function has already been instantiated.
673 const string& function_key = Canonicalize(function_name, attrs, options);
674
675 {
676 mutex_lock l(mu_);
677 const auto& it = table_.find(function_key);
678 if (it != table_.end()) {
679 *handle = it->second;
680 ++mdevice_data_[*handle]->instantiation_counter_;
681 return Status::OK();
682 }
683 }
684
685 VLOG(1) << "Instantiating MultiDevice function \"" << function_name
686 << "\" on default device \"" << options.target << "\"";
687 if (VLOG_IS_ON(3)) {
688 int index = 0;
689 VLOG(3) << "Requested input devices:";
690 for (const string& device : options.input_devices) {
691 VLOG(3) << " [input " << index++ << "] " << device;
692 }
693 index = 0;
694 VLOG(3) << "Requested output devices:";
695 for (const string& device : options.output_devices) {
696 VLOG(3) << " [output " << index++ << "] " << device;
697 }
698 }
699
700 const FunctionLibraryDefinition* lib_def =
701 options.lib_def == nullptr ? lib_def_ : options.lib_def;
702
703 const FunctionDef* fdef = lib_def->Find(function_name);
704 if (fdef == nullptr) {
705 return errors::InvalidArgument("Failed to find function \"", function_name,
706 "\" in function library: ", lib_def);
707 }
708
709 TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));
710
711 std::unique_ptr<Graph> graph;
712 std::vector<Node*> arg_nodes, ret_nodes;
713 std::vector<string> ret_node_names;
714 DataTypeVector ret_types;
715 std::vector<string> control_ret_node_names;
716
717 TF_RETURN_IF_ERROR(GetGraphAndArgRets(
718 function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
719 &ret_node_names, &ret_types, &control_ret_node_names));
720
721 if (options.graph_collector != nullptr) {
722 GraphDef def;
723 graph->ToGraphDef(&def);
724 *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
725 options.graph_collector->CollectRawGraph(def);
726 }
727
728 Device* default_device = nullptr;
729 if (options.default_device_to_target && !options.target.empty()) {
730 // Make the `target` device the default device if nothing else is hard
731 // coded. This allows the same function definition to be specialized to
732 // different devices depending on the `PartitionedCallOp` device.
733 FunctionLibraryRuntime* flr = GetFLR(options.target);
734 if (flr == nullptr) {
735 return errors::InvalidArgument(
736 "Cannot instantiate multi-device function with target device ",
737 options.target);
738 }
739 default_device = flr->device();
740 }
741 const std::shared_ptr<DeviceSet> dev_set = device_set();
742
743 TF_RETURN_IF_ERROR(
744 SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
745 TF_RETURN_IF_ERROR(PinArgsAndRets(
746 options.input_devices, options.output_devices, *dev_set, arg_nodes,
747 ret_nodes,
748 options.config_proto.allow_soft_placement() ? default_device : nullptr));
749
750 auto data = absl::make_unique<MultiDeviceFunctionData>(
751 function_name, function_key, ret_node_names.size(),
752 lib_def->ReachableDefinitions(*fdef), std::move(ret_types));
753
754 // Do not run function/graph optimization passes for component functions,
755 // since those passes have already been run on the main function.
756 const bool should_run_optimization_passes = !options.is_component_function;
757 if (!should_run_optimization_passes) {
758 VLOG(1) << "Skipping function/graph optimization passes when instantiating "
759 "component function "
760 << function_name;
761 }
762
763 // Mapping from a function body node name to the control output name.
764 std::unordered_map<string, string> node_name_to_control_ret;
765
766 bool control_rets_updated = false;
767 if (should_run_optimization_passes) {
768 TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
769 *dev_set, options.config_proto, &graph, &data->lib_def_,
770 &control_ret_node_names, &control_rets_updated));
771 }
772
773 if (control_rets_updated) {
774 // Function graph pass may have resulted in different nodes/node names for
775 // control rets.
776 for (const auto& control_ret : control_ret_node_names) {
777 node_name_to_control_ret.emplace(control_ret, control_ret);
778 }
779 } else {
780 for (const auto& control_ret : fdef->control_ret()) {
781 node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
782 }
783 }
784
785 GraphOptimizationPassOptions optimization_options;
786 // TODO(iga): Thread other relevant options from SessionOptions.
787 SessionOptions session_options;
788 session_options.env = env_;
789 session_options.config = options.config_proto;
790 optimization_options.session_options = &session_options;
791 optimization_options.graph = &graph;
792 optimization_options.flib_def = &data->lib_def_;
793 optimization_options.device_set = dev_set.get();
794 optimization_options.is_function_graph = true;
795
796 DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
797 if (should_run_optimization_passes) {
798 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
799 OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
800 }
801
802 // TODO(b/124993244): Smartly merge options in nested defuns, and raise
803 // exceptions/warnings in cases where nested function call options are ignored.
804 DumpGraph("Before calling Placer", graph.get());
805 Placer placer(graph.get(), function_name, optimization_options.flib_def,
806 dev_set.get(), default_device,
807 options.config_proto.allow_soft_placement(),
808 options.config_proto.log_device_placement());
809 TF_RETURN_IF_ERROR(placer.Run());
810
811 DumpGraph("Before running POST_PLACEMENT passes", graph.get());
812 if (should_run_optimization_passes) {
813 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
814 OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
815 }
816
817 Device* cpu_device;
818 TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));
819
820 if (options.optimize_graph_fn) {
821 DumpGraph("Before running graph optimization fn", graph.get());
822 Status status = options.optimize_graph_fn(
823 std::move(ret_node_names), std::move(control_ret_node_names),
824 &data->lib_def_, *dev_set, cpu_device, &graph);
825 if (!status.ok()) {
826 LOG(WARNING) << "Ignoring multi-device function optimization failure: "
827 << status.ToString();
828 }
829 DumpGraph("After optimization", graph.get());
830 }
831
832 DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
833 if (should_run_optimization_passes) {
834 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
835 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
836 }
837
838 // Expand the nodes assigned to a CompositeDevice before graph partitioning to
839 // avoid generating a subgraph on a virtual device for execution.
840 // This transformation should happen as late as possible, in order to run as
841 // many graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
842 // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) as possible on the smaller graph.
843 TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
844 options.composite_devices, graph.get()));
845
846 if (options.graph_collector != nullptr) {
847 GraphDef def;
848 graph->ToGraphDef(&def);
849 *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
850 options.graph_collector->CollectOptimizedGraph(def);
851 }
852
853 VLOG(4) << "Main function graph to be partitioned:";
854 VLOG(4) << DebugString(graph->ToGraphDefDebug());
855
856 std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
857 TF_RETURN_IF_ERROR(
858 PartitionFunctionGraph(*dev_set, std::move(graph), &subgraphs));
859
860 for (const auto& pair : subgraphs) {
861 DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
862 pair.first, ")"),
863 pair.second.get());
864 }
865 optimization_options.graph = nullptr;
866 optimization_options.device_set = nullptr;
867 optimization_options.partition_graphs = &subgraphs;
868 // Normally POST_PARTITIONING passes are run by distributed workers.
869 // Distributed workers are currently not supported in this code path, so we
870 // run the passes here.
871 if (should_run_optimization_passes) {
872 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
873 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
874 }
875 for (const auto& pair : subgraphs) {
876 const auto* optimized_subgraph = pair.second.get();
877 DumpGraph(
878 strings::StrCat("After all optimization passes (", pair.first, ")"),
879 optimized_subgraph);
880 if (VLOG_IS_ON(1)) {
881 DumpGraphDefToFile(
882 strings::StrCat("pflr_after_all_optimization_passes_",
883 reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
884 pair.first),
885 optimized_subgraph->ToGraphDefDebug());
886 }
887 }
888
889 if (options.graph_collector != nullptr) {
890 for (const auto& pair : subgraphs) {
891 GraphDef def;
892 pair.second->ToGraphDef(&def);
893 *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
894 options.graph_collector->CollectPartitionedGraph(def);
895 }
896 }
897
898 // We must preserve control returns in each of the function components,
899 // otherwise after function inlining we might prune side-effectful nodes.
900 const auto control_ret =
901 [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
902 const auto it = node_name_to_control_ret.find(n->name());
903 return it != node_name_to_control_ret.end()
904 ? absl::make_optional<string>(it->second)
905 : absl::nullopt;
906 };
907
908 int i = 0;
909 // Generate a random function name so that one function does not reuse the
910 // partition functions instantiated by another function.
911 FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
912 FunctionNameGenerator name_generator(
913 data_lib_def, absl::StrCat(function_name, "_", random::New64()));
914 auto subgraph_size = subgraphs.size();
915 gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
916 BlockingCounter counter(static_cast<int>(subgraph_size));
917 auto runner = [this, subgraph_size](std::function<void()> fn) {
918 // NOTE: Only use the thread pool to instantiate sub-functions when there
919 // are more than 8 sub-functions. We want to avoid the cost of thread
920 // switching when there are only a few sub-functions.
921 if (default_thread_pool_ != nullptr && subgraph_size > 8) {
922 default_thread_pool_->Schedule(fn);
923 } else {
924 fn();
925 }
926 };
927 for (const auto& pair : subgraphs) {
928 Status* status = &instantiate_status[i];
929 string unique_name = name_generator.GetName();
930 ComponentFunctionData* comp_data = &data->glue_[pair.first];
931 runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
932 &control_ret, &options, status, &counter, &data] {
933 const string& target = pair.first;
934
935 const string& device_type =
936 dev_set->FindDeviceByName(target)->device_type();
937 Graph* subgraph = pair.second.get();
938
939 status->Update(UpdateArgAndRetvalMetadata(
940 subgraph, device_type, &comp_data->arg_indices,
941 &comp_data->ret_indices, &comp_data->arg_alloc_attrs,
942 &comp_data->ret_alloc_attrs));
943 if (!status->ok()) {
944 counter.DecrementCount();
945 return;
946 }
947 FunctionDef shard;
948 status->Update(
949 GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
950 if (!status->ok()) {
951 counter.DecrementCount();
952 return;
953 }
954 status->Update(data_lib_def->AddFunctionDef(shard));
955 if (!status->ok()) {
956 counter.DecrementCount();
957 return;
958 }
959 FunctionLibraryRuntime::InstantiateOptions opts;
960 opts.executor_type = options.executor_type;
961 opts.target = target;
962 opts.lib_def = data_lib_def;
963 opts.create_kernels_eagerly = options.create_kernels_eagerly;
964 opts.state_handle = options.state_handle;
965 auto attrs = AttrSlice(&shard.attr());
966 VLOG(1) << "Start instantiating component function " << unique_name
967 << " on device " << target;
968 VLOG(4) << DebugString(shard);
969
970 auto* component_handle = new FunctionLibraryRuntime::Handle;
971 auto done = [this, status, unique_name, comp_data, component_handle,
972 &data, &counter](const Status& s) {
973 status->Update(s);
974
975 VLOG(1) << "Finished instantiating component function " << unique_name
976 << " with handle " << *component_handle << " status: " << s;
977 if (status->ok()) {
978 {
979 mutex_lock l(mu_);
980 if (function_data_[*component_handle]->is_cross_process()) {
981 data->is_cross_process_ = true;
982 }
983 }
984 comp_data->handle = *component_handle;
985 }
986 delete component_handle;
987 counter.DecrementCount();
988 };
989
990 FunctionLibraryRuntime* flr = GetFLR(opts.target);
991 if (flr != nullptr) {
992 // Initialize local function synchronously.
993 Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
994 done(s);
995 } else {
996 opts.ret_indices = comp_data->ret_indices;
997 // Initialize remote function asynchronously.
998 InstantiateRemote(unique_name, attrs, opts, component_handle, done);
999 }
1000 });
1001 i += 1;
1002 }
1003 counter.Wait();
1004 StatusGroup group;
1005 for (auto& status : instantiate_status) {
1006 group.Update(status);
1007 }
1008 TF_RETURN_IF_ERROR(group.as_summary_status());
1009
1010 *handle = AddMultiDeviceHandle(std::move(data), function_key);
1011 VLOG(2) << "Instantiated MultiDevice function \"" << function_name
1012 << "\" with handle " << *handle;
1013 return Status::OK();
1014 }
1015
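// Fills `output_devices` with the device producing each return value of a
// multi-device function: DT_RESOURCE outputs map to the component's target
// device, while other outputs use the ret allocator attributes to choose
// between the target device and its host.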
1016 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
1017 FunctionLibraryRuntime::Handle handle,
1018 std::vector<Device*>* output_devices) const {
1019 MultiDeviceFunctionData* data = IsMultiDevice(handle);
1020 if (data == nullptr) {
1021 return errors::InvalidArgument(
1022 "Failed for find multi-device function handle ", handle);
1023 }
1024
1025 for (const auto& pair : data->glue_) {
1026 const ComponentFunctionData& comp_data = pair.second;
1027 DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
1028 if (comp_data.ret_indices.empty()) {
1029 continue;
1030 }
1031
1032 const string& target = pair.first;
1033 FunctionLibraryRuntime* target_flr = GetFLR(target);
1034 Device* target_device = nullptr;
1035 Device* host = nullptr;
1036 if (target_flr == nullptr) {
1037 if (!data->has_remote_outputs) {
1038 data->has_remote_outputs = true;
1039 }
1040 target_device = device_set()->FindDeviceByName(target);
1041 string remote_host;
1042 TF_RETURN_IF_ERROR(
1043 DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
1044 host = device_set()->FindDeviceByName(remote_host);
1045 } else {
1046 target_device = target_flr->device();
1047 }
1048 output_devices->resize(data->num_outputs_);
1049 for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
1050 int ret_index = comp_data.ret_indices[j];
1051 if (data->ret_types_[ret_index] == DT_RESOURCE) {
1052 (*output_devices)[ret_index] = target_device;
1053 } else {
1054 (*output_devices)[ret_index] =
1055 comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
1056 }
1057 }
1058 }
1059
1060 return Status::OK();
1061 }
1062
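// Runs each component function of a multi-device function on its target
// device. `get_component_args` extracts the per-component arguments, and the
// returns from all components are scattered into `rets` using the recorded
// ret indices. A failure in any component cancels the remaining ones through
// the shared CancellationManager.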
1063 void ProcessFunctionLibraryRuntime::RunMultiDevice(
1064 const FunctionLibraryRuntime::Options& opts,
1065 FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
1066 std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
1067 FunctionLibraryRuntime::DoneCallback done,
1068 std::function<Status(const ComponentFunctionData& comp_data,
1069 InternalArgs* args)>
1070 get_component_args) const {
1071 if (opts.create_rendezvous) {
1072 // FLR->Run() is the default entry point. It checks for cancellation,
1073 // creates rendezvous, etc.
1074 // Letting create_rendezvous through will do the wrong thing - each
1075 // component function will get a separate rendezvous created by its FLR.
1076 done(
1077 errors::Internal("Cannot call ProcessFunctionLibraryRuntime::Run with "
1078 "create_rendezvous=true. Please run the function "
1079 "using FunctionLibraryRuntime::Run"));
1080 return;
1081 }
1082
1083 const MultiDeviceFunctionData* data = IsMultiDevice(handle);
1084 if (data == nullptr) {
1085 done(errors::NotFound("Multi-device function handle ", handle,
1086 "not found. Was the function instantiated?"));
1087 return;
1088 }
1089
1090 VLOG(1) << "Running multi-device function " << data->function_name_;
1091 VLOG(4) << " with " << opts.DebugString();
1092
1093 if (data->glue_.empty()) {
1094 // Trivial case where the function body is empty.
1095 done(Status::OK());
1096 return;
1097 }
1098
1099 // Check whether we have the right rendezvous.
1100 if (opts.rendezvous && data->is_cross_process_ &&
1101 !opts.rendezvous->is_cross_process()) {
1102 done(errors::InvalidArgument(
1103 "Running a cross process function ", data->function_name_,
1104 " without an appropriate cross process Rendezvous."));
1105 return;
1106 }
1107
1108 // A locally created cancellation manager, used only when the caller does
1109 // not provide one as an argument.
1110 std::shared_ptr<CancellationManager> local_cm;
1111 CancellationManager* cm = opts.cancellation_manager;
1112 if (cm == nullptr) {
1113 local_cm = std::make_shared<CancellationManager>();
1114 cm = local_cm.get();
1115 }
1116
1117 auto* refcounted_done = new ReffedStatusCallback(std::move(done));
1118 for (int i = 0; i < data->glue_.size(); ++i) {
1119 refcounted_done->Ref();
1120 }
1121
1122 FunctionLibraryRuntime::Options opts_copy = opts;
1123 for (const auto& pair : data->glue_) {
1124 const string& target = pair.first;
1125 const ComponentFunctionData& comp_data = pair.second;
1126 FunctionLibraryRuntime::Handle handle = pair.second.handle;
1127
1128 opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
1129 opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
1130 opts_copy.cancellation_manager = cm;
1131
1132 InternalArgs comp_args;
1133 Status s = get_component_args(comp_data, &comp_args);
1134 if (!s.ok()) {
1135 VLOG(2) << "Failed to get component function arguments: " << s;
1136 refcounted_done->UpdateStatus(s);
1137 refcounted_done->Unref();
1138 cm->StartCancel();
1139 continue;
1140 }
1141 std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
1142 rets->resize(data->num_outputs_);
1143
1144 auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
1145 cm, local_cm, data, handle,
1146 target](const Status& status) {
1147 if (!status.ok()) {
1148 VLOG(2) << "Component function execution on target " << target
1149 << " from " << data->function_name_ << " with handle " << handle
1150 << " failed: " << status;
1151 const string function_and_msg = strings::StrCat(
1152 errors::FormatFunctionForError(data->function_name_), " ",
1153 status.error_message());
1154 refcounted_done->UpdateStatus(Status(status.code(), function_and_msg));
1155 // Cancel the execution of other component functions.
1156 cm->StartCancel();
1157 } else {
1158 VLOG(2) << "Component function execution on target " << target
1159 << " from " << data->function_name_ << " with handle " << handle
1160 << " succeeded.";
1161 for (int i = 0; i < comp_rets->size(); ++i) {
1162 (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
1163 }
1164 }
1165 delete comp_rets;
1166 // refcounted_done is thread-safe
1167 refcounted_done->Unref();
1168 };
1169
1170 FunctionLibraryRuntime* flr = GetFLR(target);
1171 if (flr != nullptr) {
1172 opts_copy.remote_execution = false;
1173 // When the target device has a private thread pool, use the target
1174 // device's runner.
1175 thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
1176 opts_copy.runner = (pool == nullptr) ? opts_copy.runner : flr->runner();
1177
1178 VLOG(1) << "Running component function on device " << target << " from "
1179 << data->function_name_ << " with handle " << handle;
1180 VLOG(4) << " with " << opts_copy.DebugString();
1181
1182 std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
1183 flr->Run(
1184 opts_copy, handle, GetLocalArgs(comp_args.args), comp_tensor_rets,
1185 TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
1186 std::move(component_fn_callback)));
1187 } else {
1188 opts_copy.remote_execution = true;
1189
1190 VLOG(1) << "Running component function on device " << target << " from "
1191 << data->function_name_ << " with handle " << handle;
1192 VLOG(4) << " with " << opts_copy.DebugString();
1193
1194 RunInternal(opts_copy, handle, comp_args.args, comp_rets, cleanup_items,
1195 std::move(component_fn_callback));
1196 }
1197 }
1198 refcounted_done->Unref();
1199 }
1200
1201 Status ProcessFunctionLibraryRuntime::Instantiate(
1202 const string& function_name, AttrSlice attrs,
1203 const FunctionLibraryRuntime::InstantiateOptions& options,
1204 FunctionLibraryRuntime::Handle* handle) {
1205 if (options.is_multi_device_function) {
1206 return InstantiateMultiDevice(function_name, attrs, options, handle);
1207 }
1208
1209 *handle = kInvalidHandle;
1210 FunctionLibraryRuntime* flr = GetFLR(options.target);
1211 if (flr != nullptr) {
1212 return flr->Instantiate(function_name, attrs, options, handle);
1213 }
1214
1215 Status status;
1216 Notification notification;
1217 InstantiateRemote(function_name, attrs, options, handle,
1218 [&status, ¬ification](const Status& s) {
1219 status = s;
1220 notification.Notify();
1221 });
1222 notification.WaitForNotification();
1223 return status;
1224 }
1225
1226 Status ProcessFunctionLibraryRuntime::IsCrossProcess(
1227 FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const {
1228 tf_shared_lock l(mu_);
1229 const auto& mdevice_it = mdevice_data_.find(handle);
1230 if (mdevice_it != mdevice_data_.end()) {
1231 *is_cross_process = mdevice_it->second->is_cross_process_;
1232 return Status::OK();
1233 }
1234 const auto& it = function_data_.find(handle);
1235 if (it != function_data_.end()) {
1236 *is_cross_process = it->second->is_cross_process();
1237 return Status::OK();
1238 }
1239 return errors::InvalidArgument("Handle ", handle, " not found.");
1240 }
1241
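// Instantiates `function_name` on a remote target through `parent_`. A
// handle is allocated (or reused) under the canonicalized function key
// before the asynchronous distributed initialization is started.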
1242 void ProcessFunctionLibraryRuntime::InstantiateRemote(
1243 const string& function_name, AttrSlice attrs,
1244 const FunctionLibraryRuntime::InstantiateOptions& options,
1245 FunctionLibraryRuntime::Handle* handle,
1246 FunctionLibraryRuntime::DoneCallback done) {
1247 if (parent_ == nullptr) {
1248 done(errors::Internal(
1249 "Currently don't support instantiating functions on device: ",
1250 options.target));
1251 return;
1252 }
1253 auto target = options.target;
1254 VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target;
1255 string function_key = Canonicalize(function_name, attrs, options);
1256 FunctionData* f;
1257 {
1258 mutex_lock l(mu_);
1259 FunctionLibraryRuntime::Handle h =
1260 gtl::FindWithDefault(table_, function_key, kInvalidHandle);
1261 if (h == kInvalidHandle || function_data_.count(h) == 0) {
1262 h = AddHandleLocked(function_key, target, kInvalidHandle);
1263 }
1264 f = function_data_[h].get();
1265 *handle = h;
1266 }
1267 f->DistributedInit(
1268 parent_, function_name,
1269 options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options,
1270 [this, function_name, target, handle, done](const Status& s) {
1271 VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
1272 << " on: " << target << " with handle: " << *handle
1273 << " (this: " << this << ")";
1274 done(s);
1275 });
1276 }
1277
1278 Status ProcessFunctionLibraryRuntime::RemoveHandle(
1279 FunctionLibraryRuntime::Handle handle) {
1280 mutex_lock l(mu_);
1281 table_.erase(function_data_[handle]->function_key());
1282 function_data_.erase(handle);
1283 return Status::OK();
1284 }
1285
1286 Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
1287 FunctionLibraryRuntime::Handle handle) {
1288 std::unique_ptr<MultiDeviceFunctionData> mdata;
1289 {
1290 mutex_lock l(mu_);
1291 auto it = mdevice_data_.find(handle);
1292 --it->second->instantiation_counter_;
1293 if (it->second->instantiation_counter_ != 0) {
1294 return Status::OK();
1295 }
1296 mdata = std::move(it->second);
1297 table_.erase(mdata->function_key_);
1298 mdevice_data_.erase(it);
1299 }
1300
1301 // If we are here we are releasing the last instantiation of `handle`.
1302 // Release all component function handles.
1303 Status overall_status;
1304 for (const auto& it : mdata->glue_) {
1305 const string& device = it.first;
1306 FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
1307 FunctionLibraryRuntime* flr = GetFLR(device);
1308 if (flr == nullptr) {
1309 // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
1310 // parent is not null.
1311 if (parent_ != nullptr) {
1312 return errors::Unimplemented(
1313 "Releasing a multi-device component handle on a remote device is "
1314 "not yet implemented.");
1315 }
1316 return errors::InvalidArgument(
1317 "Failed to find FunctionLibraryRuntime for device ", device,
1318 " when releasing multi-device function handle ", handle);
1319 }
1320 Status status = flr->ReleaseHandle(flr_handle);
1321 if (!status.ok()) {
1322 overall_status = status;
1323 }
1324 }
1325
1326 return overall_status;
1327 }
1328
1329 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
1330 FunctionLibraryRuntime::Handle handle) {
1331 // Return directly if all function handles have already been released.
1332 if (flr_map_ == nullptr) return Status::OK();
1333
1334 if (IsMultiDevice(handle)) {
1335 return ReleaseMultiDeviceHandle(handle);
1336 }
1337
1338 FunctionLibraryRuntime* flr = nullptr;
1339 string target_device;
1340 {
1341 mutex_lock l(mu_);
1342 CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
1343 target_device = function_data_[handle]->target_device();
1344 }
1345 flr = GetFLR(target_device);
1346 if (flr != nullptr) {
1347 return flr->ReleaseHandle(handle);
1348 }
1349 return errors::InvalidArgument("Handle not found: ", handle);
1350 }
1351
1352 FunctionLibraryRuntime::DoneCallback
1353 ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
1354 std::vector<std::unique_ptr<CleanUpItem>>* items,
1355 FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
1356 const Rendezvous* created_rendezvous) const {
1357 return
1358 [this, items, done = std::move(done), step_id,
1359 created_rendezvous](const Status& status) {
1360 if (created_rendezvous) {
1361 DCHECK(rendezvous_factory_);
1362 created_rendezvous->Unref();
1363 Status s = rendezvous_factory_.CleanUp(step_id);
1364 if (!s.ok()) {
1365 LOG(ERROR) << s;
1366 }
1367 }
1368 auto* local_status = new Status(status);
1369 CleanUp(items, [local_status, done](const Status& cleanup_status) {
1370 local_status->Update(cleanup_status);
1371 done(*local_status);
1372 delete local_status;
1373 });
1374 delete items;
1375 };
1376 }
1377
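// Creates a per-step rendezvous through the configured factory; fails with
// FailedPrecondition when no rendezvous factory was provided at construction.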
Status ProcessFunctionLibraryRuntime::CreateRendezvous(
    const FunctionLibraryRuntime::Options& opts,
    Rendezvous** created_rendezvous) const {
  if (rendezvous_factory_) {
    return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
  } else {
    return errors::FailedPrecondition(
        "The caller does not provide a rendezvous and "
        "ProcessFunctionLibraryRuntime was created without a rendezvous "
        "factory.");
  }
}

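// Runs an instantiated function on plain Tensor arguments. If the caller did
// not supply a rendezvous, one is created for the step; the done callback is
// wrapped with cleanup logic; and execution is dispatched to RunMultiDevice
// for multi-device handles or to RunInternal otherwise.
//
// Illustrative call site (a sketch only; `pflr`, `handle`, `args` and `rets`
// are assumed to come from an earlier Instantiate call and caller-owned
// storage):
//
//   FunctionLibraryRuntime::Options opts;
//   pflr->Run(opts, handle, args, &rets,
//             [](const Status& s) { TF_CHECK_OK(s); });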
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  FunctionLibraryRuntime::Options new_opts = opts;
  Rendezvous* created_rendezvous = nullptr;
  if (!opts.rendezvous) {
    Status s = CreateRendezvous(opts, &created_rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    new_opts.rendezvous = created_rendezvous;
    new_opts.create_rendezvous = false;
  }

  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
  done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
                                    new_opts.step_id, created_rendezvous);
  std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
  done = [rets, function_rets, done = std::move(done)](const Status& s) {
    Status status = s;
    if (status.ok()) {
      for (const auto& ret : *function_rets) {
        if (ret.index() == 0) {
          rets->push_back(absl::get<Tensor>(ret));
        } else {
          status.Update(errors::Internal(
              "Expect a Tensor as a function output but got a TensorShape."));
          break;
        }
      }
    }
    delete function_rets;
    done(status);
  };
  bool multi_device;
  {
    tf_shared_lock l(mu_);
    multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
  }
  if (multi_device) {
    auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                      InternalArgs* comp_args) -> Status {
1436 // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
      for (const auto& it : comp_data.arg_indices) {
        if (it.index >= args.size()) {
          return errors::InvalidArgument(
              "index ", it.index, " is out of range [0, ", args.size(), ")");
        }
        if (it.sub_index >= 0) {
          const Tensor& t = args[it.index];
          if (t.dtype() != DT_RESOURCE) {
            return errors::InvalidArgument("Got unexpected sub_index ",
                                           it.sub_index, " for argument ",
                                           it.index);
          }
          const auto& handles = t.flat<ResourceHandle>();
          if (it.sub_index >= handles.size()) {
            return errors::InvalidArgument(
1452 "Sub_index ", it.sub_index, "is out of range [0,",
1453 handles.size(), ") for argument ", it.index);
          }
          comp_args->args.push_back(Tensor(handles(it.sub_index)));
        } else {
          comp_args->args.push_back(args[it.index]);
        }
      }
      return Status::OK();
    };
    return RunMultiDevice(new_opts, handle, function_rets, cleanup_items,
                          std::move(done), std::move(get_component_args));
  }
  std::vector<FunctionArg> local_args;
  for (const auto& tensor : args) {
    local_args.push_back(tensor);
  }
  RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
              std::move(done));
}

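// Runs a non multi-device function. If the target device has a local
// FunctionLibraryRuntime, arguments are sent to it through the rendezvous and
// the return values are received back asynchronously; otherwise the call is
// delegated to the distributed parent runtime and a CleanUpItem is recorded
// so the remote state can be cleaned up later.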
void ProcessFunctionLibraryRuntime::RunInternal(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
    std::vector<FunctionRet>* rets,
    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
    FunctionLibraryRuntime::DoneCallback done) const {
  FunctionLibraryRuntime* flr = nullptr;
  string target_device;
  FunctionLibraryRuntime::LocalHandle local_handle;
  {
    tf_shared_lock l(mu_);
    auto iter = function_data_.find(handle);
    if (iter == function_data_.end()) {
      done(errors::NotFound("Handle: ", handle, " not found."));
      return;
    }
    FunctionData* function_data = iter->second.get();
    target_device = function_data->target_device();
    local_handle = function_data->local_handle();
  }

  if (!opts.remote_execution) {
    done(
        errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
                                "only be called for multi-device functions or "
                                "for remote execution."));
    return;
  }

  flr = GetFLR(target_device);
  if (flr != nullptr) {
    auto rendezvous = opts.rendezvous;
    string source_device = opts.source_device;
    DeviceContext* device_context;
    Status s = GetDeviceContext(source_device, &device_context);
    if (!s.ok()) {
      done(s);
      return;
    }
    int64 src_incarnation, target_incarnation;
    s = GetDeviceIncarnation(source_device, &src_incarnation);
    s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
    if (!s.ok()) {
      done(s);
      return;
    }

    std::vector<Tensor> local_args = GetLocalArgs(args);

    // Send the args over to the target device.
    s = SendTensors(source_device, target_device, "arg_", src_incarnation,
                    local_args, device_context, opts.args_alloc_attrs,
                    rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    const std::vector<AllocatorAttributes>& rets_alloc_attrs =
        opts.rets_alloc_attrs;
    std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
    flr->Run(opts, handle, local_args, remote_rets,
             [source_device, target_device, target_incarnation, rendezvous,
              device_context, rets_alloc_attrs, remote_rets, rets,
              done = std::move(done)](const Status& status) mutable {
               if (!status.ok()) {
                 delete remote_rets;
                 done(status);
                 return;
               }
               int64 num_returns = remote_rets->size();
               delete remote_rets;
               // Now receive the return values from the target.
               std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
               ReceiveTensorsAsync(target_device, source_device, "ret_",
                                   target_incarnation, num_returns,
                                   device_context, rets_alloc_attrs, rendezvous,
                                   recv_tensors,
                                   TensorsToFunctionRetsDoneCallback(
                                       rets, recv_tensors, std::move(done)));
             });
    return;
  }
  if (parent_ != nullptr) {
    auto cleanup_item = absl::make_unique<CleanUpItem>();
    cleanup_item->device = target_device;
    cleanup_item->step_id = opts.step_id;
    cleanup_item->local_handle = local_handle;
    cleanup_items->emplace_back(std::move(cleanup_item));
    parent_->Run(opts, local_handle, args, rets, std::move(done));
    return;
  }
  done(errors::Internal("Could not find device"));
}

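// CallFrame-based overload: copies the frame's arguments into a Tensor vector,
// runs the function, and writes the results back into the frame's return
// values.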
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
    FunctionLibraryRuntime::DoneCallback done) const {
  std::vector<Tensor> args;
  args.reserve(frame->num_args());
  for (size_t i = 0; i < frame->num_args(); ++i) {
    const Tensor* arg;
    Status s = frame->GetArg(i, &arg);
    if (!s.ok()) {
      done(s);
      return;
    }
    args.emplace_back(*arg);
  }
  std::vector<Tensor>* rets = new std::vector<Tensor>;
  rets->reserve(frame->num_retvals());

  Run(opts, handle, args, rets,
      [frame, rets, done = std::move(done)](const Status& status) {
        std::unique_ptr<std::vector<Tensor>> rets_releaser(rets);

        if (!status.ok()) {
          done(status);
          return;
        }

        if (rets->size() != frame->num_retvals()) {
          done(errors::Internal(
              "Number of return values from function (", rets->size(),
              ") did not match expected number of return values (",
              frame->num_retvals(), ")."));
          return;
        }

        for (size_t i = 0; i < frame->num_retvals(); ++i) {
          Status s = frame->SetRetval(i, (*rets)[i]);
          if (!s.ok()) {
            done(s);
            return;
          }
        }
        done(Status::OK());
      });
}

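// Synchronous wrapper around the asynchronous Run above; blocks on a
// Notification until the done callback fires and returns the final Status.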
Status ProcessFunctionLibraryRuntime::RunSync(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets) const {
  Notification n;
  Status s;
  Run(opts, handle, args, rets, [&n, &s](const Status& status) {
    s.Update(status);
    n.Notify();
  });
  n.WaitForNotification();
  return s;
}

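// Synchronous wrapper for the CallFrameInterface-based Run overload.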
Status ProcessFunctionLibraryRuntime::RunSync(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
  Notification n;
  Status s;
  Run(opts, handle, frame, [&n, &s](const Status& status) {
    s.Update(status);
    n.Notify();
  });
  n.WaitForNotification();
  return s;
}

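// Overload taking a FunctionArgsInterface, whose inputs may be remote or
// packed. If all inputs and outputs are local it falls back to the
// Tensor-based Run; otherwise each component argument is resolved to either a
// local Tensor or a RemoteTensorHandle before dispatching to RunMultiDevice.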
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
    std::vector<FunctionRet>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  bool has_remote_outputs = false;
  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
  if (data != nullptr) {
    has_remote_outputs = data->has_remote_outputs;
  }
  if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
    const std::vector<Tensor> local_inputs = args.GetLocalTensors();
    std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
    return Run(
        opts, handle, local_inputs, tensor_rets,
        TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done)));
  }

  FunctionLibraryRuntime::Options new_opts = opts;
  Rendezvous* created_rendezvous = nullptr;
  if (!opts.rendezvous) {
    Status s = CreateRendezvous(opts, &created_rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    new_opts.rendezvous = created_rendezvous;
    new_opts.create_rendezvous = false;
  }

#if defined(IS_MOBILE_PLATFORM)
  done(errors::Unimplemented(
      "Remote inputs are not available on mobile devices."));
  return;
#else  // !IS_MOBILE_PLATFORM
  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
  done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
                                    created_rendezvous);

  auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                    InternalArgs* comp_args) -> Status {
    for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
      const FunctionArgIndex index = comp_data.arg_indices.at(i);
      Tensor tensor;
      if (args.GetLocalArg(index, &tensor).ok()) {
        comp_args->args.push_back(std::move(tensor));
      } else {
        eager::RemoteTensorHandle remote_handle;
        TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
        comp_args->remote_args.emplace_back(
            absl::make_unique<eager::RemoteTensorHandle>(
                std::move(remote_handle)));
        comp_args->args.push_back(comp_args->remote_args.back().get());
      }
    }
    return Status::OK();
  };
  return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done),
                        std::move(get_component_args));
#endif  // !IS_MOBILE_PLATFORM
}

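// Fans out cleanup of the recorded CleanUpItems. Local items are unexpected
// and reported as internal errors; remote items are forwarded to the parent
// runtime. `done` is invoked once all per-item callbacks have completed.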
void ProcessFunctionLibraryRuntime::CleanUp(
    std::vector<std::unique_ptr<CleanUpItem>>* items,
    FunctionLibraryRuntime::DoneCallback done) const {
  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (auto& item : *items) {
    refcounted_done->Ref();
    auto* flr = GetFLR(item->device);
    if (flr != nullptr) {
      // TODO(fishx): cleanup state for local execution.
      refcounted_done->UpdateStatus(
          errors::Internal("Cleanup items shouldn't contain local item."));
      refcounted_done->Unref();
    } else if (parent_ != nullptr) {
      parent_->CleanUp(item->step_id, item->local_handle,
                       [refcounted_done](const Status& status) {
                         if (!status.ok()) {
                           refcounted_done->UpdateStatus(status);
                         }
                         // refcounted_done is thread-safe
                         refcounted_done->Unref();
                       });
    } else {
      refcounted_done->UpdateStatus(
          errors::Internal("Could not find device in cleanup."));
      refcounted_done->Unref();
    }
  }
  refcounted_done->Unref();
}

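// Creates a new ProcessFunctionLibraryRuntime sharing this one's device
// manager, parent and configuration, together with a FunctionLibraryDefinition
// that is either a copy of the current library or, when `skip_flib_def` is
// true, an empty library with the same default registry.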
Status ProcessFunctionLibraryRuntime::Clone(
    Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    bool skip_flib_def) const {
  if (skip_flib_def) {
    *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(
        lib_def_->default_registry(), FunctionDefLibrary{});
  } else {
    *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_);
  }
  *out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
      out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
      session_metadata_, rendezvous_factory_);
  {
    tf_shared_lock l(mu_);
    for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
  }
  return Status::OK();
}

}  // namespace tensorflow