/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"

#include <iterator>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";

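// Instantiates this function on the distributed runtime `parent`. Only the
// first caller kicks off the instantiation; later callers block until
// `init_done_` is notified and are completed with the recorded result.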
void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
    DistributedFunctionLibraryRuntime* parent, const string& function_name,
    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::DoneCallback done) {
  {
    mutex_lock l(mu_);
    is_cross_process_ = true;
    if (init_started_) {
      init_done_.WaitForNotification();
      done(init_result_);
      return;
    }
    init_started_ = true;
  }
  parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
                      [this, done](const Status& s) {
                        init_done_.Notify();
                        done(s);
                      });
}

ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    thread::ThreadPool* default_thread_pool,
    DistributedFunctionLibraryRuntime* parent,
    const SessionMetadata* session_metadata,
    Rendezvous::Factory rendezvous_factory)
    : parent_(parent),
      env_(env),
      config_(config ? absl::make_optional(*config) : absl::nullopt),
      device_mgr_(device_mgr),
      lib_def_(lib_def),
      default_thread_pool_(default_thread_pool),
      flr_map_(new std::unordered_map<Device*,
                                      std::unique_ptr<FunctionLibraryRuntime>>),
      next_handle_(0),
      session_metadata_(session_metadata),
      rendezvous_factory_(std::move(rendezvous_factory)),
      optimizer_options_(optimizer_options),
      graph_def_version_(graph_def_version) {
  if (device_mgr == nullptr) {
    (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
        nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
        graph_def_version, lib_def_, default_thread_pool, optimizer_options,
        session_metadata_, this);
    return;
  }
  InitializeDeviceAndFlr();
}

/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
    const string& source_device, const string& target_device,
    const string& key_prefix, int64_t src_incarnation,
    gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
    const std::vector<AllocatorAttributes>& alloc_attrs,
    RendezvousInterface* rendezvous) {
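  // Each tensor i is sent under a rendezvous key built from the source device,
  // its incarnation, the target device, and the name `key_prefix + i`;
  // ReceiveTensorsAsync builds matching keys on the receiving side.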
  std::vector<string> keys;
  for (int i = 0; i < tensors_to_send.size(); ++i) {
    string name = strings::StrCat(key_prefix, i);
    string key = Rendezvous::CreateKey(source_device, src_incarnation,
                                       target_device, name, FrameAndIter(0, 0));
    keys.push_back(key);
  }
  TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
      rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
  return Status::OK();
}

/* static */
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
    const string& source_device, const string& target_device,
    const string& key_prefix, int64_t src_incarnation, int64_t num_tensors,
    DeviceContext* device_context,
    const std::vector<AllocatorAttributes>& alloc_attrs,
    RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
    StatusCallback done) {
  std::vector<string> keys;
  for (int64_t i = 0; i < num_tensors; ++i) {
    string name = strings::StrCat(key_prefix, i);
    string key = Rendezvous::CreateKey(source_device, src_incarnation,
                                       target_device, name, FrameAndIter(0, 0));
    keys.push_back(key);
  }
  RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
                                 received_tensors, std::move(done));
}

Status ProcessFunctionLibraryRuntime::GetRetTypes(
    FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) {
  FunctionLibraryRuntime* flr = nullptr;
  {
    tf_shared_lock l(mu_);
    auto miter = mdevice_data_.find(h);
    if (miter != mdevice_data_.end()) {
      *ret_types = miter->second->ret_types_;
      return Status::OK();
    }
    auto fiter = function_data_.find(h);
    if (fiter != function_data_.end()) {
      flr = GetFLR(fiter->second->target_device());
    }
  }
  if (flr != nullptr) {
    return flr->GetRetTypes(h, ret_types);
  }
  return errors::InvalidArgument("Handle ", h, " not found.");
}

Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
    const string& device_name, int64* incarnation) const {
  FunctionLibraryRuntime* flr = GetFLR(device_name);
  if (flr == nullptr) {
    return errors::InvalidArgument("Device name: ", device_name, " not found.");
  }
  *incarnation = flr->device()->attributes().incarnation();
  return Status::OK();
}

Status ProcessFunctionLibraryRuntime::GetDeviceContext(
    const string& device_name, DeviceContext** device_context) const {
  *device_context = nullptr;
  FunctionLibraryRuntime* flr = GetFLR(device_name);
  if (flr == nullptr) {
    return errors::InvalidArgument("Device name: ", device_name, " not found.");
  }
  Device* device = flr->device();
  string device_type = device->parsed_name().type;
  if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
    // "TPU_SYSTEM" indicates that `device` is a CPU.
    return Status::OK();
  }

  if (device->IsRemoteCallAllowed()) {
    auto* dev_info = flr->device()->tensorflow_gpu_device_info();
    if (dev_info) {
      *device_context = dev_info->default_context;
      return Status::OK();
    }
  }

  return errors::Internal("Device type: ", device_type,
                          " is currently unsupported for remote ",
                          "function executions");
}

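// Builds the DeviceSet from the local device manager (plus the remote device
// manager, when a distributed parent runtime is present) and creates a
// per-device FunctionLibraryRuntime for every local device that does not
// already have one.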
void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
  DeviceMgr const* all_devices = device_mgr_;
  if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
    all_devices = parent_->remote_device_mgr();
  }

  mutex_lock l(mu_);
  device_set_ = std::make_shared<DeviceSet>();
  for (auto d : all_devices->ListDevices()) {
    device_set_->AddDevice(d);
  }
  for (Device* d : device_mgr_->ListDevices()) {
    if ((*flr_map_)[d] == nullptr) {
      (*flr_map_)[d] = NewFunctionLibraryRuntime(
          device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
          graph_def_version_, lib_def_, default_thread_pool_,
          optimizer_options_, session_metadata_, this);
    }
  }
}

FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
    const string& device_name) const {
  Device* device = nullptr;
  if (device_name != kDefaultFLRDevice) {
    if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
      VLOG(4) << "Could not find device: " << device_name;
      return nullptr;
    }
  }
  const auto& iter = flr_map_->find(device);
  if (iter == flr_map_->end()) {
    VLOG(1) << "Could not find device: " << device_name
            << " in the local process.";
    return nullptr;
  }
  return iter->second.get();
}

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
    const string& function_key, const string& device_name,
    FunctionLibraryRuntime::LocalHandle local_handle) {
  mutex_lock l(mu_);
  return AddHandleLocked(function_key, device_name, local_handle);
}

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
    const string& function_key, const string& device_name,
    FunctionLibraryRuntime::LocalHandle local_handle) {
  auto h = next_handle_;
  function_data_[h] =
      absl::make_unique<FunctionData>(device_name, local_handle, function_key);
  table_[function_key] = h;
  next_handle_++;
  return h;
}

FunctionLibraryRuntime::Handle
ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
    std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
  mutex_lock l(mu_);
  auto h = next_handle_;
  mdevice_data_[h] = std::move(data);
  table_[function_key] = h;
  next_handle_++;
  return h;
}

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
    const string& function_key) const {
  tf_shared_lock l(mu_);
  return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
}

bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle) const {
  return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
}

FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle,
    bool include_multi_device) const {
  tf_shared_lock l(mu_);

  auto miter = mdevice_data_.find(handle);
  if (miter != mdevice_data_.end()) {
    if (!include_multi_device) return kInvalidLocalHandle;

    const MultiDeviceFunctionData& data = *miter->second;
    if (data.glue_.size() != 1) return kInvalidLocalHandle;

    const auto& pair = *data.glue_.begin();
    const string& func_device_name = pair.first;
    const ComponentFunctionData& component_data = pair.second;
    if (func_device_name != device_name) return kInvalidLocalHandle;

    // Replace the given handle with the handle for the single component
    // function.
    handle = component_data.handle;
  }

  auto iter = function_data_.find(handle);
  if (iter == function_data_.end()) {
    return kInvalidLocalHandle;
  }
  FunctionData* function_data = iter->second.get();
  if (function_data->target_device() != device_name) {
    return kInvalidLocalHandle;
  }
  return function_data->local_handle();
}

string ProcessFunctionLibraryRuntime::GetDeviceName(
    FunctionLibraryRuntime::Handle handle) const {
  tf_shared_lock l(mu_);
  auto iter = function_data_.find(handle);
  CHECK(iter != function_data_.end());
  FunctionData* function_data = iter->second.get();
  return function_data->target_device();
}

ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
ProcessFunctionLibraryRuntime::IsMultiDevice(
    FunctionLibraryRuntime::Handle handle) const {
  tf_shared_lock l(mu_);
  const auto& it = mdevice_data_.find(handle);
  if (it != mdevice_data_.end()) {
    return it->second.get();
  }
  return nullptr;
}

namespace {
// Sets `group` to the first colocation group specified in `node`. If no
// group is specified, does not touch `group`.
void GetColocationGroup(const Node* node, string* group) {
  // We hoist the conversion from C-style string literal to string here,
  // so that we can avoid the many repeated calls to strlen().
  static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
  const AttrValue* attr_value =
      node->attrs().Find(kColocationAttrNameStringPiece);
  if (attr_value != nullptr && attr_value->has_list() &&
      attr_value->list().s_size() > 0) {
    *group = attr_value->list().s(0);
  }
}

const string* AssignedOrRequestedDeviceName(const Node& node) {
  if (node.has_assigned_device_name()) {
    return &node.assigned_device_name();
  }
  return &node.requested_device();
}

Status SetArgShape(const std::unordered_map<int, DtypeAndPartialTensorShape>&
                       input_resource_dtypes_and_shapes,
                   const std::vector<Node*>& arg_nodes) {
  for (Node* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    if (dtype == DT_RESOURCE) {
      auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
      if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
        AttrValue dtype_attr_value;
        dtype_attr_value.mutable_list()->add_type(
            dtype_and_shape_iter->second.dtype);
        n->AddAttr("_handle_dtypes", dtype_attr_value);
        TensorShapeProto shape_proto;
        dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
        AttrValue shape_attr_value;
        *shape_attr_value.mutable_list()->add_shape() = shape_proto;
        n->AddAttr("_handle_shapes", shape_attr_value);
      }
    }
  }
  return Status::OK();
}

// Returns the local tensors referred by `args`.
std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
  std::vector<Tensor> tensors;
  for (const auto& arg : args) {
    if (arg.index() == 0) {
      tensors.push_back(absl::get<Tensor>(arg));
    }
  }
  return tensors;
}

// Update the done callback to push Tensors in `tensors` into `rets`.
FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
    std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
    FunctionLibraryRuntime::DoneCallback done) {
  return [rets, tensors, done = std::move(done)](const Status& s) {
    if (s.ok()) {
      for (const auto& t : *tensors) {
        rets->push_back(t);
      }
    }
    delete tensors;
    done(s);
  };
}

}  // anonymous namespace

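// Pins each arg node to its requested input device. When output devices are
// not specified, each retval is pinned based on the device (or colocation
// group) of the node that produces it; otherwise the explicitly requested
// output device is used.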
Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
    const std::vector<string>& input_devices,
    const std::vector<string>& output_devices, const DeviceSet& device_set,
    const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
    const FunctionLibraryDefinition* lib_def, Device* default_device) {
  // If output_devices are not specified, we want to set the output device
  // based on the device of the output producing node. The output producing
  // node can be an arg node because functions can simply return their
  // arguments. To make sure that the output producing nodes have assigned
  // devices, we assign them to arguments first.
  for (Node* node : arg_nodes) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
    int64_t index = attr_value->i();
    node->set_assigned_device_name(input_devices[index]);
  }

  for (Node* node : ret_nodes) {
    if (output_devices.empty()) {
      DataType dtype;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));

      VLOG(3) << "Trying to determine device for node " << node->name()
              << "[T=" << DataTypeString(dtype) << "]";

      // If output_devices are empty, the node producing the retval must have
      // an explicitly assigned device or a colocation constraint to a node
      // with an explicitly assigned device.
      for (const auto& it : node->in_edges()) {
        if (it->IsControlEdge()) continue;

        Node* src_node = it->src();
        const string* src_device = AssignedOrRequestedDeviceName(*src_node);
        string colocation_group = "";
        GetColocationGroup(src_node, &colocation_group);
        VLOG(3) << "Considering src: " << src_node->name()
                << " src_device: " << *src_device
                << " colo group: " << colocation_group;
        while (src_device->empty() && colocation_group.empty() &&
               src_node->IsIdentity()) {
          // Only follows the real data input of Identity, not control edges.
          Node* input_node;
          TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
          src_node = input_node;

          src_device = AssignedOrRequestedDeviceName(*src_node);
          GetColocationGroup(src_node, &colocation_group);
          VLOG(3) << "Considering src: " << src_node->name()
                  << " src_device: " << *src_device
                  << " colo group: " << colocation_group;
        }

        // If the resource is produced by a function call node, we can't trust
        // the source node's device assignment, because multi-device functions
        // can return resources placed on multiple devices. In such a case we
        // leave the retval device assignment empty and rely on the placer to
        // infer the correct assignment based on the actual output device.
        const bool can_use_src_node_device =
            !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def, *src_node));

        if (!colocation_group.empty()) {
          AttrValue::ListValue colo_attr;
          colo_attr.add_s(colocation_group);
          std::vector<string> colo_slice = {colocation_group};
          node->AddAttr(kColocationAttrName, colo_slice);
        } else if (!src_device->empty() && can_use_src_node_device) {
          // src_device can be a partially specified device. Find the
          // matching device in the device_set.
          DeviceNameUtils::ParsedName parsed;
          if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
            return errors::InvalidArgument(
                "Failed to parse explicit device specification ", *src_device);
          }
          std::vector<Device*> matching_devices;
          device_set.FindMatchingDevices(parsed, &matching_devices);
          if (matching_devices.empty()) {
            if (default_device != nullptr) {
              matching_devices.push_back(default_device);
            } else {
              return errors::InvalidArgument(
                  "Unable to find any devices for spec ", *src_device);
            }
          } else if (matching_devices.size() != 1) {
            bool on_same_task = true;
            for (int i = 1; i < matching_devices.size(); ++i) {
              if (!DeviceNameUtils::IsSameAddressSpace(
                      matching_devices.at(0)->parsed_name(),
                      matching_devices.at(i)->parsed_name())) {
                on_same_task = false;
                break;
              }
            }
            // If the src node of an output is assigned to an address space
            // (e.g. py_func), rely on the placer to assign a device to the
            // output.
            if (on_same_task) {
              continue;
            }
            // Compare with default_device if it has a narrower scope matching
            // the requested device.
            int colocated_on_default_device = 0;
            for (int i = 0; i < matching_devices.size(); ++i) {
              if (DeviceNameUtils::IsSameAddressSpace(
                      default_device->parsed_name(),
                      matching_devices.at(i)->parsed_name())) {
                colocated_on_default_device++;
              }
            }
            // Continue (rather than raising an error) only when exactly one of
            // the matching devices is colocated with the default device.
            if (colocated_on_default_device == 1) {
              continue;
            }

            // Convert a vector of devices to a string.
            // Using absl::StrJoin did not work in Android builds.
            string devices = "[";
            for (Device* device : matching_devices) {
              devices.append(device->name());
              devices.append(", ");
            }
            if (devices.size() > 2) {
              devices.resize(devices.size() - 2);
            }
            devices.append("]");

            return errors::InvalidArgument(
                *src_device,
                "When FunctionLibraryRuntime::Options.output_devices are "
                "not specified for a multi-device function, the device "
                "specification on the output node must match exactly one "
                "device. Matched devices are ",
                devices);
          }
          VLOG(3) << "Setting output device to " << matching_devices[0]->name()
                  << " for node " << SummarizeNode(*node);
          node->set_assigned_device_name(matching_devices[0]->name());
        } else if (!src_device->empty() && !can_use_src_node_device) {
          VLOG(3) << "Did not set device for a resource output node "
                  << SummarizeNode(*node);
        }
      }
    } else {
      const AttrValue* attr_value;
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int64_t index = attr_value->i();
      // output_devices size is checked in InstantiateMultiDevice
      DCHECK_GT(output_devices.size(), index);
      VLOG(3) << "Setting output device to " << output_devices[index]
              << " for return at index " << index;
      node->set_assigned_device_name(output_devices[index]);
    }
  }
  return Status::OK();
}

namespace {

Status ValidateNoListArguments(
    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
    const string& function_name) {
  for (const OpDef::ArgDef& arg : args) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      return errors::InvalidArgument(
          "Function ", function_name, " has an ", arg_type, " named \"",
          arg.name(),
          "\" that is a list of tensors."
          " Multi-device functions support only single-tensor inputs "
          "and outputs");
    }
  }
  return Status::OK();
}

Status ValidateMultiDeviceOptions(
    const FunctionDef& fdef,
    const FunctionLibraryRuntime::InstantiateOptions& options) {
  const OpDef& signature = fdef.signature();
  // Multi-device functions currently do not support list inputs or outputs.
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
                                             signature.name()));
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
                                             signature.name()));
  if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
      fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
    return errors::Unimplemented(
        "Function '", signature.name(), "' has `",
        FunctionLibraryDefinition::kIntsOnDeviceAttr,
        "` attribute set. This attribute is not currently supported by "
        "multi-device functions.");
  }
  if (options.input_devices.size() != signature.input_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.input_devices must have the same length "
        "as the number of arguments: input_devices length = ",
        options.input_devices.size(),
        " number of arguments = ", signature.input_arg_size());
  }
  if (!options.output_devices.empty() &&
      options.output_devices.size() != signature.output_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.output_devices must either be empty or have the "
        "same length as the number of arguments: output_devices length = ",
        options.output_devices.size(),
        " number of arguments = ", signature.output_arg_size());
  }
  return Status::OK();
}

}  // anonymous namespace

Status GetGraphAndArgRets(
    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
    std::vector<string>* control_ret_node_names) {
  std::unique_ptr<FunctionBody> fbody;
  // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
  if (!fbody) {
    LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
    return errors::Internal("Failed to construct FunctionBody for ",
                            function_name);
  }
  *graph = std::unique_ptr<Graph>(fbody->graph);
  arg_nodes->reserve(fbody->arg_nodes.size());
  std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(),
            std::back_inserter(*arg_nodes));
  ret_nodes->reserve(fbody->ret_nodes.size());
  std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(),
            std::back_inserter(*ret_nodes));
  fbody->graph = nullptr;
  ret_node_names->reserve(fbody->ret_nodes.size());
  for (const Node* node : fbody->ret_nodes) {
    ret_node_names->push_back(node->name());
  }
  for (const auto& ret_type : fbody->ret_types) {
    ret_types->push_back(ret_type);
  }
  control_ret_node_names->reserve(fbody->control_ret_nodes.size());
  for (const Node* node : fbody->control_ret_nodes) {
    control_ret_node_names->push_back(node->name());
  }
  return Status::OK();
}

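// Instantiates a multi-device function: reuses a cached handle when one
// exists; otherwise builds a Graph from the FunctionDef, pins args/retvals,
// runs placement and the registered optimization passes, partitions the graph
// per device, and instantiates each partition as a component function.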
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::Handle* handle) {
  // Check if this function has already been instantiated.
  const string& function_key = Canonicalize(function_name, attrs, options);

  {
    mutex_lock l(mu_);
    const auto& it = table_.find(function_key);
    if (it != table_.end()) {
      *handle = it->second;
      ++mdevice_data_[*handle]->instantiation_counter_;
      return Status::OK();
    }
  }

  VLOG(1) << "Instantiating MultiDevice function \"" << function_name
          << "\" on default device \"" << options.target << "\"";
  if (VLOG_IS_ON(3)) {
    int index = 0;
    VLOG(3) << "Requested input devices:";
    for (const string& device : options.input_devices) {
      VLOG(3) << " [input " << index++ << "] " << device;
    }
    index = 0;
    VLOG(3) << "Requested output devices:";
    for (const string& device : options.output_devices) {
      VLOG(3) << " [output " << index++ << "] " << device;
    }
  }

  const FunctionLibraryDefinition* lib_def =
      options.lib_def == nullptr ? lib_def_ : options.lib_def;

  const FunctionDef* fdef = lib_def->Find(function_name);
  if (fdef == nullptr) {
    return errors::InvalidArgument("Failed to find function \"", function_name,
                                   "\" in function library: ", lib_def);
  }

  TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));

  std::unique_ptr<Graph> graph;
  std::vector<Node*> arg_nodes, ret_nodes;
  std::vector<string> ret_node_names;
  DataTypeVector ret_types;
  std::vector<string> control_ret_node_names;

  TF_RETURN_IF_ERROR(GetGraphAndArgRets(
      function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
      &ret_node_names, &ret_types, &control_ret_node_names));

  GraphDef graph_def;
  graph->ToGraphDef(&graph_def);
  FunctionLibraryDefinition reachable_lib_def =
      lib_def->ReachableDefinitions(graph_def);
  *graph_def.mutable_library() = reachable_lib_def.ToProto();
  if (options.graph_collector != nullptr) {
    options.graph_collector->CollectRawGraph(graph_def);
  }

  Device* default_device = nullptr;
  if (options.default_device_to_target && !options.target.empty()) {
    // Make the `target` device the default device if nothing else is hard
    // coded. This allows the same function definition to be specialized to
    // different devices depending on the `PartitionedCallOp` device.
    FunctionLibraryRuntime* flr = GetFLR(options.target);
    if (flr == nullptr) {
      return errors::InvalidArgument(
          "Cannot instantiate multi-device function with target device ",
          options.target);
    }
    default_device = flr->device();
  }
  const std::shared_ptr<DeviceSet> dev_set = device_set();

  TF_RETURN_IF_ERROR(
      SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
  TF_RETURN_IF_ERROR(PinArgsAndRets(
      options.input_devices, options.output_devices, *dev_set, arg_nodes,
      ret_nodes, lib_def_,
      options.config_proto.allow_soft_placement() ? default_device : nullptr));

  auto data = absl::make_unique<MultiDeviceFunctionData>(
      function_name, function_key, ret_node_names.size(),
      std::move(reachable_lib_def), std::move(ret_types));

  // The runtime shouldn't depend on duplication between the function library
  // owned by the graph and the one owned by the runtime. To ensure this, for
  // now we ensure that the graph function library is empty and the runtime
  // library receives the query from LookUps on the graph function library.
  graph->mutable_flib_def()->set_default_registry(&data->lib_def_);
  graph->mutable_flib_def()->Clear();

  // Do not run function/graph optimization passes for component functions,
  // since these passes have already been run on the main function.
  const bool should_run_optimization_passes = !options.is_component_function;
  if (!should_run_optimization_passes) {
    VLOG(1) << "Skipping function/graph optimization passes when instantiating "
               "component function "
            << function_name;
  }

  // Mapping from a function body node name to the control output name.
  std::unordered_map<string, string> node_name_to_control_ret;

  bool control_rets_updated = false;
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
        *dev_set, options.config_proto, &graph, &data->lib_def_,
        &control_ret_node_names, &control_rets_updated));
  }

  if (control_rets_updated) {
    // Function graph pass may have resulted in different nodes/node names for
    // control rets.
    for (const auto& control_ret : control_ret_node_names) {
      node_name_to_control_ret.emplace(control_ret, control_ret);
    }
  } else {
    for (const auto& control_ret : fdef->control_ret()) {
      node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
    }
  }

  GraphOptimizationPassOptions optimization_options;
  // TODO(iga): Thread other relevant options from SessionOptions.
  SessionOptions session_options;
  session_options.env = env_;
  session_options.config = options.config_proto;
  optimization_options.session_options = &session_options;
  optimization_options.graph = &graph;
  optimization_options.flib_def = &data->lib_def_;
  optimization_options.device_set = dev_set.get();
  optimization_options.is_function_graph = true;
  std::vector<CompositeDevice*> composite_devices;
  {
    tf_shared_lock l(mu_);
    for (auto* d : composite_devices_) composite_devices.push_back(d);
  }
  optimization_options.composite_devices = &composite_devices;
  optimization_options.default_function_device = default_device;
  optimization_options.function_def = fdef;

  DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
  }

  // TODO(b/124993244): Smartly merge options in nested defuns, and raise
  // exceptions/warnings in cases where nested function call options are
  // ignored.
  DumpGraph("Before calling Placer", graph.get());
  Placer placer(graph.get(), function_name, optimization_options.flib_def,
                dev_set.get(), default_device,
                options.config_proto.allow_soft_placement(),
                options.config_proto.log_device_placement());
  TF_RETURN_IF_ERROR(placer.Run());

  DumpGraph("Before running POST_PLACEMENT passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
  }

  Device* cpu_device;
  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));

  if (options.optimize_graph_fn) {
    DumpGraph("Before running graph optimization fn", graph.get());
    Status status = options.optimize_graph_fn(
        std::move(ret_node_names), std::move(control_ret_node_names),
        &data->lib_def_, *dev_set, cpu_device, &graph);
    if (!status.ok()) {
      LOG(WARNING) << "Ignoring multi-device function optimization failure: "
                   << status.ToString();
    }
    DumpGraph("After optimization", graph.get());
  }

  DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
  }

  // Expand the nodes assigned to a CompositeDevice before graph partitioning
  // to avoid generating a subgraph on a virtual device for execution.
  // This transformation should happen as late as possible, so that as many
  // graph optimization passes as possible (e.g. PRE_PLACEMENT, PLACER,
  // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) run on the smaller graph.
  TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
      options.composite_devices, graph.get()));

  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectOptimizedGraph(def);
  }

  VLOG(4) << "Main function graph to be partitioned:";
  VLOG(4) << DebugString(graph->ToGraphDefDebug());

  std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(*dev_set, std::move(graph), &subgraphs));

  for (const auto& pair : subgraphs) {
    DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
                              pair.first, ")"),
              pair.second.get());
  }
  optimization_options.graph = nullptr;
  optimization_options.device_set = nullptr;
  optimization_options.partition_graphs = &subgraphs;
  // Normally POST_PARTITIONING passes are run by distributed workers.
  // Distributed workers are currently not supported in this code path, so we
  // run the passes here.
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
  }
  for (const auto& pair : subgraphs) {
    const auto* optimized_subgraph = pair.second.get();
    DumpGraph(
        strings::StrCat("After all optimization passes (", pair.first, ")"),
        optimized_subgraph);
    if (VLOG_IS_ON(1)) {
      DumpGraphDefToFile(
          strings::StrCat("pflr_after_all_optimization_passes_",
                          reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
                          pair.first),
          optimized_subgraph->ToGraphDefDebug());
    }
  }

  if (options.graph_collector != nullptr) {
    for (const auto& pair : subgraphs) {
      GraphDef def;
      pair.second->ToGraphDef(&def);
      *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
      options.graph_collector->CollectPartitionedGraph(def);
    }
  }

  // We must preserve control returns in each of the function components,
  // otherwise after function inlining we might prune side-effectful nodes.
  const auto control_ret =
      [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
    const auto it = node_name_to_control_ret.find(n->name());
    return it != node_name_to_control_ret.end()
               ? absl::make_optional<string>(it->second)
               : absl::nullopt;
  };

  int i = 0;
  // Generate a random function_name to avoid one function reusing the
  // partitioned functions instantiated by another function.
  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
  FunctionNameGenerator name_generator(
      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
  auto subgraph_size = subgraphs.size();
  gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
  BlockingCounter counter(static_cast<int>(subgraph_size));
  auto runner = [this, subgraph_size](std::function<void()> fn) {
    // NOTE: Only use the thread pool to instantiate sub-functions when there
    // are more than 8 sub-functions. We want to avoid the cost of switching
    // threads when there are only a few sub-functions.
    if (default_thread_pool_ != nullptr && subgraph_size > 8) {
      default_thread_pool_->Schedule(fn);
    } else {
      fn();
    }
  };
  for (const auto& pair : subgraphs) {
    Status* status = &instantiate_status[i];
    string unique_name = name_generator.GetName();
    ComponentFunctionData* comp_data = &data->glue_[pair.first];
    runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
            &control_ret, &options, status, &counter, &data] {
      const string& target = pair.first;

      const string& device_type =
          dev_set->FindDeviceByName(target)->device_type();
      Graph* subgraph = pair.second.get();

      status->Update(UpdateArgAndRetvalMetadata(
          subgraph, device_type, &comp_data->arg_indices,
          &comp_data->ret_indices, &comp_data->arg_alloc_attrs,
          &comp_data->ret_alloc_attrs));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      FunctionDef shard;
      status->Update(
          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      status->Update(data_lib_def->AddFunctionDef(shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      FunctionLibraryRuntime::InstantiateOptions opts;
      opts.executor_type = options.executor_type;
      opts.target = target;
      opts.lib_def = data_lib_def;
      opts.create_kernels_eagerly = options.create_kernels_eagerly;
      opts.state_handle = options.state_handle;
      auto attrs = AttrSlice(&shard.attr());
      VLOG(1) << "Start instantiating component function " << unique_name
              << " on device " << target;
      VLOG(4) << DebugString(shard);

      auto* component_handle = new FunctionLibraryRuntime::Handle;
      auto done = [this, status, unique_name, comp_data, component_handle,
                   &data, &counter](const Status& s) {
        status->Update(s);

        VLOG(1) << "Finished instantiating component function " << unique_name
                << " with handle " << *component_handle << " status: " << s;
        if (status->ok()) {
          {
            mutex_lock l(mu_);
            if (function_data_[*component_handle]->is_cross_process()) {
              data->is_cross_process_ = true;
            }
          }
          comp_data->handle = *component_handle;
        }
        delete component_handle;
        counter.DecrementCount();
      };

      FunctionLibraryRuntime* flr = GetFLR(opts.target);
      if (flr != nullptr) {
        // Initialize local function synchronously.
        Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
        done(s);
      } else {
        opts.ret_indices = comp_data->ret_indices;
        // Initialize remote function asynchronously.
        InstantiateRemote(unique_name, attrs, opts, component_handle, done);
      }
    });
    i += 1;
  }
  counter.Wait();
  StatusGroup group;
  for (auto& status : instantiate_status) {
    group.Update(status);
  }
  TF_RETURN_IF_ERROR(group.as_summary_status());

  *handle = AddMultiDeviceHandle(std::move(data), function_key);
  VLOG(2) << "Instantiated MultiDevice function \"" << function_name
          << "\" with handle " << *handle;
  return Status::OK();
}

Status ProcessFunctionLibraryRuntime::GetOutputDevices(
    FunctionLibraryRuntime::Handle handle,
    std::vector<Device*>* output_devices) const {
  MultiDeviceFunctionData* data = IsMultiDevice(handle);
  if (data == nullptr) {
    return errors::InvalidArgument(
        "Failed to find multi-device function handle ", handle);
  }

  for (const auto& pair : data->glue_) {
    const ComponentFunctionData& comp_data = pair.second;
    DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
    if (comp_data.ret_indices.empty()) {
      continue;
    }

    const string& target = pair.first;
    FunctionLibraryRuntime* target_flr = GetFLR(target);
    Device* target_device = nullptr;
    Device* host = nullptr;
    if (target_flr == nullptr) {
      if (!data->has_remote_outputs) {
        data->has_remote_outputs = true;
      }
      target_device = device_set()->FindDeviceByName(target);
      string remote_host;
      TF_RETURN_IF_ERROR(
          DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
      host = device_set()->FindDeviceByName(remote_host);
    } else {
      target_device = target_flr->device();
    }
    output_devices->resize(data->num_outputs_);
    for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
      int ret_index = comp_data.ret_indices[j];
      if (data->ret_types_[ret_index] == DT_RESOURCE) {
        (*output_devices)[ret_index] = target_device;
      } else {
        (*output_devices)[ret_index] =
            comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
      }
    }
  }

  return Status::OK();
}

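// Runs every component function of a multi-device function: local components
// go through their per-device FunctionLibraryRuntime, remote components go
// through RunInternal, and each component's return values are scattered into
// `rets` at the indices recorded during instantiation.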
void ProcessFunctionLibraryRuntime::RunMultiDevice(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
    FunctionLibraryRuntime::DoneCallback done,
    std::function<Status(const ComponentFunctionData& comp_data,
                         InternalArgs* args)>
        get_component_args) const {
  if (opts.create_rendezvous) {
    // FLR->Run() is the default entry point. It checks for cancellation,
    // creates rendezvous, etc.
    // Letting create_rendezvous through will do the wrong thing - each
    // component function will get a separate rendezvous created by its FLR.
    done(
        errors::Internal("Cannot call ProcessFunctionLibraryRuntime::Run with "
                         "create_rendezvous=true. Please run the function "
                         "using FunctionLibraryRuntime::Run"));
    return;
  }

  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
  if (data == nullptr) {
    done(errors::NotFound("Multi-device function handle ", handle,
                          " not found. Was the function instantiated?"));
    return;
  }

  VLOG(1) << "Running multi-device function " << data->function_name_;
  VLOG(4) << " with " << opts.DebugString();

  if (data->glue_.empty()) {
    // Trivial case where the function body is empty.
    done(Status::OK());
    return;
  }

  // Check whether we have the right rendezvous.
  if (opts.rendezvous && data->is_cross_process_ &&
      !opts.rendezvous->is_cross_process()) {
    done(errors::InvalidArgument(
        "Running a cross process function ", data->function_name_,
        " without an appropriate cross process Rendezvous."));
    return;
  }

  // A locally created cancellation manager, used only when the caller does not
  // provide one as an argument.
  std::shared_ptr<CancellationManager> local_cm;
  CancellationManager* cm = opts.cancellation_manager;
  if (cm == nullptr) {
    local_cm = std::make_shared<CancellationManager>();
    cm = local_cm.get();
  }

  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (int i = 0; i < data->glue_.size(); ++i) {
    refcounted_done->Ref();
  }

  FunctionLibraryRuntime::Options opts_copy = opts;
  for (const auto& pair : data->glue_) {
    const string& target = pair.first;
    const ComponentFunctionData& comp_data = pair.second;
    FunctionLibraryRuntime::Handle handle = pair.second.handle;

    opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
    opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
    opts_copy.cancellation_manager = cm;

    InternalArgs comp_args;
    Status s = get_component_args(comp_data, &comp_args);
    if (!s.ok()) {
      VLOG(2) << "Failed to get component function arguments: " << s;
      refcounted_done->UpdateStatus(s);
      refcounted_done->Unref();
      cm->StartCancel();
      continue;
    }
    std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
    rets->resize(data->num_outputs_);

    auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
                                  cm, local_cm, data, handle,
                                  target](const Status& status) {
      if (!status.ok()) {
        VLOG(2) << "Component function execution on target " << target
                << " from " << data->function_name_ << " with handle " << handle
                << " failed: " << status;
        const string function_and_msg = strings::StrCat(
            errors::FormatFunctionForError(data->function_name_), " ",
            status.error_message());
        refcounted_done->UpdateStatus(Status(status.code(), function_and_msg));
        // Cancel the execution of other component functions.
        cm->StartCancel();
      } else {
        VLOG(2) << "Component function execution on target " << target
                << " from " << data->function_name_ << " with handle " << handle
                << " succeeded.";
        for (int i = 0; i < comp_rets->size(); ++i) {
          (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
        }
      }
      delete comp_rets;
      // refcounted_done is thread-safe
      refcounted_done->Unref();
    };

    FunctionLibraryRuntime* flr = GetFLR(target);
    if (flr != nullptr) {
      opts_copy.remote_execution = false;
      // When the target device has a private thread pool, use the target
      // device's runner.
      thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
      opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();

      VLOG(1) << "Running component function on device " << target << " from "
              << data->function_name_ << " with handle " << handle;
      VLOG(4) << " with " << opts_copy.DebugString();

      std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
      flr->Run(
          opts_copy, handle, GetLocalArgs(comp_args.args), comp_tensor_rets,
          TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
                                            std::move(component_fn_callback)));
    } else {
      opts_copy.remote_execution = true;

      VLOG(1) << "Running component function on device " << target << " from "
              << data->function_name_ << " with handle " << handle;
      VLOG(4) << " with " << opts_copy.DebugString();

      RunInternal(opts_copy, handle, comp_args.args, comp_rets, cleanup_items,
                  std::move(component_fn_callback));
    }
  }
  refcounted_done->Unref();
}

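// Dispatches instantiation: multi-device functions go through
// InstantiateMultiDevice, targets with a local FLR are instantiated directly,
// and anything else is instantiated on the distributed runtime (waiting here
// for the asynchronous InstantiateRemote to finish).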
Instantiate(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle)1217 Status ProcessFunctionLibraryRuntime::Instantiate(
1218 const string& function_name, AttrSlice attrs,
1219 const FunctionLibraryRuntime::InstantiateOptions& options,
1220 FunctionLibraryRuntime::Handle* handle) {
1221 if (options.is_multi_device_function) {
1222 return InstantiateMultiDevice(function_name, attrs, options, handle);
1223 }
1224
1225 *handle = kInvalidHandle;
1226 FunctionLibraryRuntime* flr = GetFLR(options.target);
1227 if (flr != nullptr) {
1228 return flr->Instantiate(function_name, attrs, options, handle);
1229 }
1230
1231 Status status;
1232 Notification notification;
1233 InstantiateRemote(function_name, attrs, options, handle,
1234 [&status, ¬ification](const Status& s) {
1235 status = s;
1236 notification.Notify();
1237 });
1238 notification.WaitForNotification();
1239 return status;
1240 }
1241
IsCrossProcess(FunctionLibraryRuntime::Handle handle,bool * is_cross_process) const1242 Status ProcessFunctionLibraryRuntime::IsCrossProcess(
1243 FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const {
1244 tf_shared_lock l(mu_);
1245 const auto& mdevice_it = mdevice_data_.find(handle);
1246 if (mdevice_it != mdevice_data_.end()) {
1247 *is_cross_process = mdevice_it->second->is_cross_process_;
1248 return Status::OK();
1249 }
1250 const auto& it = function_data_.find(handle);
1251 if (it != function_data_.end()) {
1252 *is_cross_process = it->second->is_cross_process();
1253 return Status::OK();
1254 }
1255 return errors::InvalidArgument("Handle ", handle, " not found.");
1256 }
1257
InstantiateRemote(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle,FunctionLibraryRuntime::DoneCallback done)1258 void ProcessFunctionLibraryRuntime::InstantiateRemote(
1259 const string& function_name, AttrSlice attrs,
1260 const FunctionLibraryRuntime::InstantiateOptions& options,
1261 FunctionLibraryRuntime::Handle* handle,
1262 FunctionLibraryRuntime::DoneCallback done) {
1263 if (parent_ == nullptr) {
1264 done(errors::Internal(
1265 "Currently don't support instantiating functions on device: ",
1266 options.target));
1267 return;
1268 }
1269 auto target = options.target;
1270 VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target;
1271 string function_key = Canonicalize(function_name, attrs, options);
1272 FunctionData* f;
1273 {
1274 mutex_lock l(mu_);
1275 FunctionLibraryRuntime::Handle h =
1276 gtl::FindWithDefault(table_, function_key, kInvalidHandle);
1277 if (h == kInvalidHandle || function_data_.count(h) == 0) {
1278 h = AddHandleLocked(function_key, target, kInvalidHandle);
1279 }
1280 f = function_data_[h].get();
1281 *handle = h;
1282 }
1283 f->DistributedInit(
1284 parent_, function_name,
1285 options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options,
1286 [this, function_name, target, handle, done](const Status& s) {
1287 VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
1288 << " on: " << target << " with handle: " << *handle
1289 << " (this: " << this << ")";
1290 done(s);
1291 });
1292 }
1293
RemoveHandle(FunctionLibraryRuntime::Handle handle)1294 Status ProcessFunctionLibraryRuntime::RemoveHandle(
1295 FunctionLibraryRuntime::Handle handle) {
1296 mutex_lock l(mu_);
1297 table_.erase(function_data_[handle]->function_key());
1298 function_data_.erase(handle);
1299 return Status::OK();
1300 }
1301
ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle)1302 Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
1303 FunctionLibraryRuntime::Handle handle) {
1304 std::unique_ptr<MultiDeviceFunctionData> mdata;
1305 {
1306 mutex_lock l(mu_);
1307 auto it = mdevice_data_.find(handle);
1308 --it->second->instantiation_counter_;
1309 if (it->second->instantiation_counter_ != 0) {
1310 return Status::OK();
1311 }
1312 mdata = std::move(it->second);
1313 table_.erase(mdata->function_key_);
1314 mdevice_data_.erase(it);
1315 }
1316
1317 // If we are here we are releasing the last instantiation of `handle`.
1318 // Release all component function handles.
1319 Status overall_status;
1320 for (const auto& it : mdata->glue_) {
1321 const string& device = it.first;
1322 FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
1323 FunctionLibraryRuntime* flr = GetFLR(device);
1324 if (flr == nullptr) {
1325 // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
1326 // parent is not null.
1327 if (parent_ != nullptr) {
1328 return errors::Unimplemented(
1329 "Releasing a multi-device component handle on a remote device is "
1330 "not yet implemented.");
1331 }
1332 return errors::InvalidArgument(
1333 "Failed to find FunctionLibraryRuntime for device ", device,
1334 " when releasing multi-device function handle ", handle);
1335 }
1336 Status status = flr->ReleaseHandle(flr_handle);
1337 if (!status.ok()) {
1338 overall_status = status;
1339 }
1340 }
1341
1342 return overall_status;
1343 }
1344
Status ProcessFunctionLibraryRuntime::ReleaseHandle(
    FunctionLibraryRuntime::Handle handle) {
  // Return directly if all function handles have already been released.
  if (flr_map_ == nullptr) return Status::OK();

  if (IsMultiDevice(handle)) {
    return ReleaseMultiDeviceHandle(handle);
  }

  FunctionLibraryRuntime* flr = nullptr;
  string target_device;
  {
    mutex_lock l(mu_);
    CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
    target_device = function_data_[handle]->target_device();
  }
  flr = GetFLR(target_device);
  if (flr != nullptr) {
    return flr->ReleaseHandle(handle);
  }
  return errors::InvalidArgument("Handle not found: ", handle);
}

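// Wraps `done` so that, once the function finishes, the rendezvous created
// for this step (if any) is unreffed and cleaned up, the accumulated
// CleanUpItems are processed, and only then is the original callback invoked
// with the combined status.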
FunctionLibraryRuntime::DoneCallback
ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
    std::vector<std::unique_ptr<CleanUpItem>>* items,
    FunctionLibraryRuntime::DoneCallback done, const int64_t step_id,
    const Rendezvous* created_rendezvous) const {
  return [this, items, done = std::move(done), step_id,
          created_rendezvous](const Status& status) {
    if (created_rendezvous) {
      DCHECK(rendezvous_factory_);
      created_rendezvous->Unref();
      Status s = rendezvous_factory_.CleanUp(step_id);
      if (!s.ok()) {
        LOG(ERROR) << s;
      }
    }
    auto* local_status = new Status(status);
    CleanUp(items, [local_status, done](const Status& cleanup_status) {
      local_status->Update(cleanup_status);
      done(*local_status);
      delete local_status;
    });
    delete items;
  };
}

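// Creates a per-step rendezvous using the configured rendezvous factory.
// Fails with FailedPrecondition when the caller did not pass a rendezvous in
// FunctionLibraryRuntime::Options and no factory was provided at
// construction time.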
Status ProcessFunctionLibraryRuntime::CreateRendezvous(
    const FunctionLibraryRuntime::Options& opts,
    Rendezvous** created_rendezvous) const {
  if (rendezvous_factory_) {
    return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
  } else {
    return errors::FailedPrecondition(
        "The caller does not provide a rendezvous and "
        "ProcessFunctionLibraryRuntime was created without a rendezvous "
        "factory.");
  }
}

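// Runs `handle` with plain Tensor arguments. Multi-device functions are
// dispatched through RunMultiDevice, unpacking sub-indexed resource
// arguments from DT_RESOURCE tensors; everything else falls through to
// RunInternal for single-device or remote execution.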
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  FunctionLibraryRuntime::Options new_opts = opts;
  Rendezvous* created_rendezvous = nullptr;
  if (!opts.rendezvous) {
    Status s = CreateRendezvous(opts, &created_rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    new_opts.rendezvous = created_rendezvous;
    new_opts.create_rendezvous = false;
  }

  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
  done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
                                    new_opts.step_id, created_rendezvous);
  std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
  done = [rets, function_rets, done = std::move(done)](const Status& s) {
    Status status = s;
    if (status.ok()) {
      for (const auto& ret : *function_rets) {
        if (ret.index() == 0) {
          rets->push_back(absl::get<Tensor>(ret));
        } else {
          status.Update(errors::Internal("Expected a Tensor as a function "
                                         "output but got a TensorShape."));
          break;
        }
      }
    }
    delete function_rets;
    done(status);
  };
  bool multi_device;
  {
    tf_shared_lock l(mu_);
    multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
  }
  if (multi_device) {
    auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                      InternalArgs* comp_args) -> Status {
      // The "index" values of _Arg nodes are unique when all arguments are
      // local Tensors.
      for (const auto& it : comp_data.arg_indices) {
        if (it.index >= args.size()) {
          return errors::InvalidArgument(
              "index ", it.index, " is out of range [0, ", args.size(), ")");
        }
        if (it.sub_index >= 0) {
          const Tensor& t = args[it.index];
          if (t.dtype() != DT_RESOURCE) {
            return errors::InvalidArgument("Got unexpected sub_index ",
                                           it.sub_index, " for argument ",
                                           it.index);
          }
          const auto& handles = t.flat<ResourceHandle>();
          if (it.sub_index >= handles.size()) {
            return errors::InvalidArgument(
                "Sub_index ", it.sub_index, " is out of range [0, ",
                handles.size(), ") for argument ", it.index);
          }
          comp_args->args.push_back(Tensor(handles(it.sub_index)));
        } else {
          comp_args->args.push_back(args[it.index]);
        }
      }
      return Status::OK();
    };
    return RunMultiDevice(new_opts, handle, function_rets, cleanup_items,
                          std::move(done), std::move(get_component_args));
  }
  std::vector<FunctionArg> local_args;
  for (const auto& tensor : args) {
    local_args.push_back(tensor);
  }
  RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
              std::move(done));
}

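// Runs a previously instantiated single-device or remote function. When a
// local FunctionLibraryRuntime exists for the target device, the arguments
// are sent over the rendezvous, the function runs there, and the return
// values are received back; otherwise the call is delegated to the
// distributed runtime (`parent_`), with a CleanUpItem recorded for the step.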
void ProcessFunctionLibraryRuntime::RunInternal(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
    std::vector<FunctionRet>* rets,
    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
    FunctionLibraryRuntime::DoneCallback done) const {
  FunctionLibraryRuntime* flr = nullptr;
  string target_device;
  FunctionLibraryRuntime::LocalHandle local_handle;
  {
    tf_shared_lock l(mu_);
    auto iter = function_data_.find(handle);
    if (iter == function_data_.end()) {
      done(errors::NotFound("Handle: ", handle, " not found."));
      return;
    }
    FunctionData* function_data = iter->second.get();
    target_device = function_data->target_device();
    local_handle = function_data->local_handle();
  }

  if (!opts.remote_execution) {
    done(
        errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
                                "only be called for multi-device functions or "
                                "for remote execution."));
    return;
  }

  flr = GetFLR(target_device);
  if (flr != nullptr) {
    auto rendezvous = opts.rendezvous;
    string source_device = opts.source_device;
    DeviceContext* device_context;
    Status s = GetDeviceContext(source_device, &device_context);
    if (!s.ok()) {
      done(s);
      return;
    }
    int64_t src_incarnation, target_incarnation;
    s = GetDeviceIncarnation(source_device, &src_incarnation);
    s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
    if (!s.ok()) {
      done(s);
      return;
    }

    std::vector<Tensor> local_args = GetLocalArgs(args);

    // Send the args over to the target device.
    s = SendTensors(source_device, target_device, "arg_", src_incarnation,
                    local_args, device_context, opts.args_alloc_attrs,
                    rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    const std::vector<AllocatorAttributes>& rets_alloc_attrs =
        opts.rets_alloc_attrs;
    std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
    flr->Run(opts, handle, local_args, remote_rets,
             [source_device, target_device, target_incarnation, rendezvous,
              device_context, rets_alloc_attrs, remote_rets, rets,
              done = std::move(done)](const Status& status) mutable {
               if (!status.ok()) {
                 delete remote_rets;
                 done(status);
                 return;
               }
               int64_t num_returns = remote_rets->size();
               delete remote_rets;
               // Now receive the return values from the target.
               std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
               ReceiveTensorsAsync(target_device, source_device, "ret_",
                                   target_incarnation, num_returns,
                                   device_context, rets_alloc_attrs, rendezvous,
                                   recv_tensors,
                                   TensorsToFunctionRetsDoneCallback(
                                       rets, recv_tensors, std::move(done)));
             });
    return;
  }
  if (parent_ != nullptr) {
    auto cleanup_item = absl::make_unique<CleanUpItem>();
    cleanup_item->device = target_device;
    cleanup_item->step_id = opts.step_id;
    cleanup_item->local_handle = local_handle;
    cleanup_items->emplace_back(std::move(cleanup_item));
    parent_->Run(opts, local_handle, args, rets, std::move(done));
    return;
  }
  done(errors::Internal("Could not find device: ", target_device));
}

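// Runs `handle` using a CallFrameInterface: the arguments are copied out of
// the frame, the Tensor-based Run overload is invoked, and the return values
// are written back into the frame when the function succeeds.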
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
    FunctionLibraryRuntime::DoneCallback done) const {
  std::vector<Tensor> args;
  args.reserve(frame->num_args());
  for (size_t i = 0; i < frame->num_args(); ++i) {
    const Tensor* arg;
    Status s = frame->GetArg(i, &arg);
    if (!s.ok()) {
      done(s);
      return;
    }
    args.emplace_back(*arg);
  }
  std::vector<Tensor>* rets = new std::vector<Tensor>;
  rets->reserve(frame->num_retvals());

  Run(opts, handle, args, rets,
      [frame, rets, done = std::move(done)](const Status& status) {
        std::unique_ptr<std::vector<Tensor>> rets_releaser(rets);

        if (!status.ok()) {
          done(status);
          return;
        }

        if (rets->size() != frame->num_retvals()) {
          done(errors::Internal(
              "Number of return values from function (", rets->size(),
              ") did not match expected number of return values (",
              frame->num_retvals(), ")."));
          return;
        }

        for (size_t i = 0; i < frame->num_retvals(); ++i) {
          Status s = frame->SetRetval(i, (*rets)[i]);
          if (!s.ok()) {
            done(s);
            return;
          }
        }
        done(Status::OK());
      });
}

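// Blocking wrapper around the asynchronous Tensor-based Run overload above:
// runs the function and waits on a Notification until its done-callback
// fires. Illustrative use (handle, options, and tensor names are
// hypothetical):
//
//   FunctionLibraryRuntime::Options run_opts;
//   std::vector<Tensor> outputs;
//   Status s = pflr->RunSync(run_opts, my_handle, {input_tensor}, &outputs);
//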
Status ProcessFunctionLibraryRuntime::RunSync(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets) const {
  Notification n;
  Status s;
  Run(opts, handle, args, rets, [&n, &s](const Status& status) {
    s.Update(status);
    n.Notify();
  });
  n.WaitForNotification();
  return s;
}

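// Blocking counterpart of the CallFrameInterface-based Run overload.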
Status ProcessFunctionLibraryRuntime::RunSync(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
  Notification n;
  Status s;
  Run(opts, handle, frame, [&n, &s](const Status& status) {
    s.Update(status);
    n.Notify();
  });
  n.WaitForNotification();
  return s;
}

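// Runs `handle` with arguments that may include remote or packed tensor
// handles. Purely local inputs and outputs are routed through the
// Tensor-based Run overload; otherwise the component arguments (local
// tensors or RemoteTensorHandles) are gathered per component function and
// execution goes through RunMultiDevice. Remote inputs are not supported on
// mobile builds.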
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
    std::vector<FunctionRet>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  bool has_remote_outputs = false;
  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
  if (data != nullptr) {
    has_remote_outputs = data->has_remote_outputs;
  }
  if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
    const std::vector<Tensor> local_inputs = args.GetLocalTensors();
    std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
    return Run(
        opts, handle, local_inputs, tensor_rets,
        TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done)));
  }

  FunctionLibraryRuntime::Options new_opts = opts;
  Rendezvous* created_rendezvous = nullptr;
  if (!opts.rendezvous) {
    Status s = CreateRendezvous(opts, &created_rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    new_opts.rendezvous = created_rendezvous;
    new_opts.create_rendezvous = false;
  }

#if defined(IS_MOBILE_PLATFORM)
  done(errors::Unimplemented(
      "Remote inputs are not available on mobile devices."));
  return;
#else   // !IS_MOBILE_PLATFORM
  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
  done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
                                    created_rendezvous);

  auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                    InternalArgs* comp_args) -> Status {
    for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
      const FunctionArgIndex index = comp_data.arg_indices.at(i);
      Tensor tensor;
      if (args.GetLocalArg(index, &tensor).ok()) {
        comp_args->args.push_back(std::move(tensor));
      } else {
        eager::RemoteTensorHandle remote_handle;
        TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
        comp_args->remote_args.emplace_back(
            absl::make_unique<eager::RemoteTensorHandle>(
                std::move(remote_handle)));
        comp_args->args.push_back(comp_args->remote_args.back().get());
      }
    }
    return Status::OK();
  };
  return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done),
                        std::move(get_component_args));
#endif  // !IS_MOBILE_PLATFORM
}

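// Cleans up per-step state recorded for remote component executions. Local
// items are unexpected here (cleanup for local execution is still a TODO);
// remote items are forwarded to the distributed runtime, and the aggregated
// status is reported through `done` once every item has been processed.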
void ProcessFunctionLibraryRuntime::CleanUp(
    std::vector<std::unique_ptr<CleanUpItem>>* items,
    FunctionLibraryRuntime::DoneCallback done) const {
  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (auto& item : *items) {
    refcounted_done->Ref();
    auto* flr = GetFLR(item->device);
    if (flr != nullptr) {
      // TODO(fishx): cleanup state for local execution.
      refcounted_done->UpdateStatus(
          errors::Internal("Cleanup items shouldn't contain local item."));
      refcounted_done->Unref();
    } else if (parent_ != nullptr) {
      parent_->CleanUp(item->step_id, item->local_handle,
                       [refcounted_done](const Status& status) {
                         if (!status.ok()) {
                           refcounted_done->UpdateStatus(status);
                         }
                         // refcounted_done is thread-safe
                         refcounted_done->Unref();
                       });
    } else {
      refcounted_done->UpdateStatus(
          errors::Internal("Could not find device in cleanup."));
      refcounted_done->Unref();
    }
  }
  refcounted_done->Unref();
}

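// Creates a new ProcessFunctionLibraryRuntime (and FunctionLibraryDefinition)
// that shares this runtime's devices, parent, session metadata, and
// rendezvous factory. When `skip_flib_def` is true, the new library
// definition starts empty (same op registry, no functions); composite
// devices are copied over in both cases.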
Status ProcessFunctionLibraryRuntime::Clone(
    Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    bool skip_flib_def) const {
  if (skip_flib_def) {
    *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(
        lib_def_->default_registry(), FunctionDefLibrary{});
  } else {
    *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_);
  }
  *out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
      out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
      session_metadata_, rendezvous_factory_);
  {
    tf_shared_lock l(mu_);
    for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
  }
  return Status::OK();
}

}  // namespace tensorflow