1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16
17 #include <algorithm>
18 #include <iterator>
19 #include <utility>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/common_runtime/device_set.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/common_runtime/optimization_registry.h"
29 #include "tensorflow/core/common_runtime/partitioning_utils.h"
30 #include "tensorflow/core/common_runtime/placer.h"
31 #include "tensorflow/core/common_runtime/process_util.h"
32 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
33 #include "tensorflow/core/common_runtime/rendezvous_util.h"
34 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
35 #include "tensorflow/core/common_runtime/single_threaded_executor.h"
36 #include "tensorflow/core/framework/cancellation.h"
37 #include "tensorflow/core/framework/function.h"
38 #include "tensorflow/core/framework/graph_to_functiondef.h"
39 #include "tensorflow/core/framework/metrics.h"
40 #include "tensorflow/core/framework/op_kernel.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/framework/types.pb.h"
44 #include "tensorflow/core/graph/graph.h"
45 #include "tensorflow/core/graph/graph_node_util.h"
46 #include "tensorflow/core/graph/graph_partition.h"
47 #include "tensorflow/core/lib/core/blocking_counter.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/gtl/cleanup.h"
50 #include "tensorflow/core/lib/gtl/inlined_vector.h"
51 #include "tensorflow/core/lib/gtl/map_util.h"
52 #include "tensorflow/core/lib/random/random.h"
53 #include "tensorflow/core/platform/notification.h"
54 #include "tensorflow/core/util/device_name_utils.h"
55 #include "tensorflow/core/util/dump_graph.h"
56 #include "tensorflow/core/util/ptr_util.h"
57 #include "tensorflow/core/util/reffed_status_callback.h"
58 #if !defined(IS_MOBILE_PLATFORM)
59 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
60 #endif // IS_MOBILE_PLATFORM
61
62 namespace tensorflow {
63
// Sentinel device name. GetFLR() performs no DeviceMgr lookup for this name
// and instead returns the runtime registered under the null Device* key.
const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
65
DistributedInit(DistributedFunctionLibraryRuntime * parent,const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::DoneCallback done)66 void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
67 DistributedFunctionLibraryRuntime* parent, const string& function_name,
68 const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
69 const FunctionLibraryRuntime::InstantiateOptions& options,
70 FunctionLibraryRuntime::DoneCallback done) {
71 {
72 mutex_lock l(mu_);
73 is_cross_process_ = true;
74 if (init_started_) {
75 init_done_.WaitForNotification();
76 done(init_result_);
77 return;
78 }
79 init_started_ = true;
80 }
81 parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
82 [this, done](const Status& s) {
83 init_done_.Notify();
84 done(s);
85 });
86 }
87
ProcessFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * default_thread_pool,DistributedFunctionLibraryRuntime * parent,const SessionMetadata * session_metadata,Rendezvous::Factory rendezvous_factory)88 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
89 const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
90 int graph_def_version, const FunctionLibraryDefinition* lib_def,
91 const OptimizerOptions& optimizer_options,
92 thread::ThreadPool* default_thread_pool,
93 DistributedFunctionLibraryRuntime* parent,
94 const SessionMetadata* session_metadata,
95 Rendezvous::Factory rendezvous_factory)
96 : parent_(parent),
97 env_(env),
98 config_(config ? absl::make_optional(*config) : absl::nullopt),
99 device_mgr_(device_mgr),
100 lib_def_(lib_def),
101 default_thread_pool_(default_thread_pool),
102 flr_map_(new std::unordered_map<Device*,
103 std::unique_ptr<FunctionLibraryRuntime>>),
104 next_handle_(0),
105 session_metadata_(session_metadata),
106 rendezvous_factory_(std::move(rendezvous_factory)),
107 optimizer_options_(optimizer_options),
108 graph_def_version_(graph_def_version) {
109 if (device_mgr == nullptr) {
110 (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
111 nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
112 graph_def_version, lib_def_, default_thread_pool, optimizer_options,
113 session_metadata_, this);
114 return;
115 }
116 InitializeDeviceAndFlr();
117 }
118
119 /* static */
SendTensors(const string & source_device,const string & target_device,const string & key_prefix,int64_t src_incarnation,gtl::ArraySlice<Tensor> tensors_to_send,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,RendezvousInterface * rendezvous)120 Status ProcessFunctionLibraryRuntime::SendTensors(
121 const string& source_device, const string& target_device,
122 const string& key_prefix, int64_t src_incarnation,
123 gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
124 const std::vector<AllocatorAttributes>& alloc_attrs,
125 RendezvousInterface* rendezvous) {
126 std::vector<string> keys;
127 for (int i = 0; i < tensors_to_send.size(); ++i) {
128 string name = strings::StrCat(key_prefix, i);
129 string key = Rendezvous::CreateKey(source_device, src_incarnation,
130 target_device, name, FrameAndIter(0, 0));
131 keys.push_back(key);
132 }
133 TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
134 rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
135 return OkStatus();
136 }
137
138 /* static */
ReceiveTensorsAsync(const string & source_device,const string & target_device,const string & key_prefix,int64_t src_incarnation,int64_t num_tensors,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,RendezvousInterface * rendezvous,std::vector<Tensor> * received_tensors,StatusCallback done)139 void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
140 const string& source_device, const string& target_device,
141 const string& key_prefix, int64_t src_incarnation, int64_t num_tensors,
142 DeviceContext* device_context,
143 const std::vector<AllocatorAttributes>& alloc_attrs,
144 RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
145 StatusCallback done) {
146 std::vector<string> keys;
147 for (int64_t i = 0; i < num_tensors; ++i) {
148 string name = strings::StrCat(key_prefix, i);
149 string key = Rendezvous::CreateKey(source_device, src_incarnation,
150 target_device, name, FrameAndIter(0, 0));
151 keys.push_back(key);
152 }
153 RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
154 received_tensors, std::move(done));
155 }
156
GetRetTypes(FunctionLibraryRuntime::Handle h,DataTypeVector * ret_types)157 Status ProcessFunctionLibraryRuntime::GetRetTypes(
158 FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) {
159 FunctionLibraryRuntime* flr = nullptr;
160 {
161 tf_shared_lock l(mu_);
162 auto miter = mdevice_data_.find(h);
163 if (miter != mdevice_data_.end()) {
164 *ret_types = miter->second->ret_types_;
165 return OkStatus();
166 }
167 auto fiter = function_data_.find(h);
168 if (fiter != function_data_.end()) {
169 flr = GetFLR(fiter->second->target_device());
170 }
171 }
172 if (flr != nullptr) {
173 return flr->GetRetTypes(h, ret_types);
174 }
175 return errors::InvalidArgument("Handle ", h, " not found.");
176 }
177
GetDeviceIncarnation(const string & device_name,int64_t * incarnation) const178 Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
179 const string& device_name, int64_t* incarnation) const {
180 FunctionLibraryRuntime* flr = GetFLR(device_name);
181 if (flr == nullptr) {
182 return errors::InvalidArgument("Device name: ", device_name, " not found.");
183 }
184 *incarnation = flr->device()->attributes().incarnation();
185 return OkStatus();
186 }
187
GetDeviceContext(const string & device_name,DeviceContext ** device_context) const188 Status ProcessFunctionLibraryRuntime::GetDeviceContext(
189 const string& device_name, DeviceContext** device_context) const {
190 *device_context = nullptr;
191 FunctionLibraryRuntime* flr = GetFLR(device_name);
192 if (flr == nullptr) {
193 return errors::InvalidArgument("Device name: ", device_name, " not found.");
194 }
195 Device* device = flr->device();
196 string device_type = device->parsed_name().type;
197 if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
198 // "TPU_SYSTEM" indicates that `device` is a CPU.
199 return OkStatus();
200 }
201
202 if (device->IsRemoteCallAllowed()) {
203 auto* dev_info = flr->device()->tensorflow_accelerator_device_info();
204 if (dev_info) {
205 *device_context = dev_info->default_context;
206 return OkStatus();
207 }
208 }
209
210 return errors::Internal("Device type: ", device_type,
211 " is currently unsupported for remote ",
212 "function executions");
213 }
214
InitializeDeviceAndFlr()215 void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
216 // Reset device_set_ by one of the two following scenarios:
217 // 1) Both cluster-FLR and its remote_device_mgr is available: include local
218 // devices (if any) from the local device_mgr_ as Device type, and include
219 // remote devices from cluster's remote_device_mgr as RemoteDevice type.
220 // 2) Include local devices from the local device_mgr_.
221 // In both scenarios, no device is added more than one times.
222 mutex_lock l(mu_);
223 device_set_ = std::make_shared<DeviceSet>();
224 if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
225 for (auto d : parent_->remote_device_mgr()->ListDevices()) {
226 Device* device = nullptr;
227 if (device_mgr_->LookupDevice(d->name(), &device) == OkStatus()) {
228 // If this device exists in device_mgr, i.e., a local device,
229 // add this device from the instance included in device_mgr_
230 device_set_->AddDevice(device);
231 } else {
232 device_set_->AddDevice(d);
233 }
234 }
235 } else {
236 for (auto d : device_mgr_->ListDevices()) {
237 device_set_->AddDevice(d);
238 }
239 }
240
241 // Update flr_map_ by adding new devices
242 for (Device* d : device_mgr_->ListDevices()) {
243 if ((*flr_map_)[d] == nullptr) {
244 (*flr_map_)[d] = NewFunctionLibraryRuntime(
245 device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
246 graph_def_version_, lib_def_, default_thread_pool_,
247 optimizer_options_, session_metadata_, this);
248 }
249 }
250 }
251
GetFLR(const string & device_name) const252 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
253 const string& device_name) const {
254 Device* device = nullptr;
255 if (device_name != kDefaultFLRDevice) {
256 if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
257 VLOG(4) << "Could not find device: " << device_name;
258 return nullptr;
259 }
260 }
261 const auto& iter = flr_map_->find(device);
262 if (iter == flr_map_->end()) {
263 VLOG(1) << "Could not find device: " << device_name
264 << "in the local process.";
265 return nullptr;
266 }
267 return iter->second.get();
268 }
269
AddHandle(const string & function_key,const string & device_name,FunctionLibraryRuntime::LocalHandle local_handle)270 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
271 const string& function_key, const string& device_name,
272 FunctionLibraryRuntime::LocalHandle local_handle) {
273 mutex_lock l(mu_);
274 return AddHandleLocked(function_key, device_name, local_handle);
275 }
276
AddHandleLocked(const string & function_key,const string & device_name,FunctionLibraryRuntime::LocalHandle local_handle)277 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
278 const string& function_key, const string& device_name,
279 FunctionLibraryRuntime::LocalHandle local_handle) {
280 auto h = next_handle_;
281 function_data_[h] =
282 std::make_unique<FunctionData>(device_name, local_handle, function_key);
283 table_[function_key] = h;
284 next_handle_++;
285 return h;
286 }
287
288 FunctionLibraryRuntime::Handle
AddMultiDeviceHandle(std::unique_ptr<MultiDeviceFunctionData> data,const string & function_key)289 ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
290 std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
291 mutex_lock l(mu_);
292 auto h = next_handle_;
293 mdevice_data_[h] = std::move(data);
294 table_[function_key] = h;
295 next_handle_++;
296 return h;
297 }
298
HasMultiDeviceHandle(FunctionLibraryRuntime::Handle handle) const299 bool ProcessFunctionLibraryRuntime::HasMultiDeviceHandle(
300 FunctionLibraryRuntime::Handle handle) const {
301 bool multi_device;
302 {
303 tf_shared_lock l(mu_);
304 multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
305 }
306 return multi_device;
307 }
308
GetHandle(const string & function_key) const309 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
310 const string& function_key) const {
311 tf_shared_lock l(mu_);
312 return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
313 }
314
315 FunctionLibraryRuntime::LocalHandle
GetHandleOnDevice(const string & device_name,FunctionLibraryRuntime::Handle handle,bool include_multi_device) const316 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
317 const string& device_name, FunctionLibraryRuntime::Handle handle,
318 bool include_multi_device) const {
319 tf_shared_lock l(mu_);
320
321 auto miter = mdevice_data_.find(handle);
322 if (miter != mdevice_data_.end()) {
323 if (!include_multi_device) return kInvalidLocalHandle;
324
325 const MultiDeviceFunctionData& data = *miter->second;
326 if (data.glue_.size() != 1) return kInvalidLocalHandle;
327
328 const auto& pair = *data.glue_.begin();
329 const string& func_device_name = pair.first;
330 const ComponentFunctionData& component_data = pair.second;
331 if (func_device_name != device_name) return kInvalidLocalHandle;
332
333 // Replace the given handle with the handle for the single component
334 // function.
335 handle = component_data.handle;
336 }
337
338 auto iter = function_data_.find(handle);
339 if (iter == function_data_.end()) {
340 return kInvalidLocalHandle;
341 }
342 FunctionData* function_data = iter->second.get();
343 if (function_data->target_device() != device_name) {
344 return kInvalidLocalHandle;
345 }
346 return function_data->local_handle();
347 }
348
GetDeviceName(FunctionLibraryRuntime::Handle handle) const349 string ProcessFunctionLibraryRuntime::GetDeviceName(
350 FunctionLibraryRuntime::Handle handle) const {
351 tf_shared_lock l(mu_);
352 auto iter = function_data_.find(handle);
353 CHECK(iter != function_data_.end());
354 FunctionData* function_data = iter->second.get();
355 return function_data->target_device();
356 }
357
358 ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
IsMultiDevice(FunctionLibraryRuntime::Handle handle) const359 ProcessFunctionLibraryRuntime::IsMultiDevice(
360 FunctionLibraryRuntime::Handle handle) const {
361 tf_shared_lock l(mu_);
362 const auto& it = mdevice_data_.find(handle);
363 if (it != mdevice_data_.end()) {
364 return it->second.get();
365 }
366 return nullptr;
367 }
368
369 namespace {
370 // Sets `group` to the first colocation group specified in `node`. If no
371 // group is specified, does not touch `group`.
GetColocationGroup(const Node * node,string * group)372 void GetColocationGroup(const Node* node, string* group) {
373 // We hoist the conversion from C-style string literal to string here,
374 // so that we can avoid the many repeated calls to strlen().
375 static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
376 const AttrValue* attr_value =
377 node->attrs().Find(kColocationAttrNameStringPiece);
378 if (attr_value != nullptr && attr_value->has_list() &&
379 attr_value->list().s_size() > 0) {
380 *group = attr_value->list().s(0);
381 }
382 }
383
AssignedOrRequestedDeviceName(const Node & node)384 const string* AssignedOrRequestedDeviceName(const Node& node) {
385 if (node.has_assigned_device_name()) {
386 return &node.assigned_device_name();
387 }
388 return &node.requested_device();
389 }
390
SetArgShape(const std::unordered_map<int,DtypeAndPartialTensorShape> & input_resource_dtypes_and_shapes,const std::vector<Node * > & arg_nodes)391 Status SetArgShape(const std::unordered_map<int, DtypeAndPartialTensorShape>&
392 input_resource_dtypes_and_shapes,
393 const std::vector<Node*>& arg_nodes) {
394 for (Node* n : arg_nodes) {
395 int index;
396 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
397 DataType dtype;
398 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
399 if (dtype == DT_RESOURCE) {
400 auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
401 if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
402 AttrValue dtype_attr_value;
403 dtype_attr_value.mutable_list()->add_type(
404 dtype_and_shape_iter->second.dtype);
405 n->AddAttr("_handle_dtypes", dtype_attr_value);
406 TensorShapeProto shape_proto;
407 dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
408 AttrValue shape_attr_value;
409 *shape_attr_value.mutable_list()->add_shape() = shape_proto;
410 n->AddAttr("_handle_shapes", shape_attr_value);
411 }
412 }
413 }
414 return OkStatus();
415 }
416
417 // Returns the local tensors referred by `args`.
GetLocalArgs(gtl::ArraySlice<FunctionArg> args)418 std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
419 std::vector<Tensor> tensors;
420 for (const auto& arg : args) {
421 if (arg.index() == 0) {
422 tensors.push_back(absl::get<Tensor>(arg));
423 }
424 }
425 return tensors;
426 }
427
428 // Update the done callback to push Tensors in `tensors` into `rets`.
TensorsToFunctionRetsDoneCallback(std::vector<FunctionRet> * rets,std::vector<Tensor> * tensors,FunctionLibraryRuntime::DoneCallback done)429 FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
430 std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
431 FunctionLibraryRuntime::DoneCallback done) {
432 return [rets, tensors, done = std::move(done)](const Status& s) {
433 if (s.ok()) {
434 for (const auto& t : *tensors) {
435 rets->push_back(t);
436 }
437 }
438 delete tensors;
439 done(s);
440 };
441 }
442
443 // Push Tensors in `function_rets` into `tensors`.
FunctionRetsToTensors(const std::vector<FunctionRet> * function_rets,std::vector<Tensor> * tensors)444 Status FunctionRetsToTensors(const std::vector<FunctionRet>* function_rets,
445 std::vector<Tensor>* tensors) {
446 for (const auto& ret : *function_rets) {
447 if (ret.index() != 0) {
448 return errors::Internal(
449 "Expect a Tensor as a function output but got a TensorShape.");
450 }
451 tensors->push_back(absl::get<Tensor>(ret));
452 }
453 return OkStatus();
454 }
455
456 } // anonymous namespace
457
PinArgsAndRets(const std::vector<string> & input_devices,const std::vector<string> & output_devices,const DeviceSet & device_set,const std::vector<Node * > & arg_nodes,const std::vector<Node * > & ret_nodes,const FunctionLibraryDefinition * lib_def,Device * default_device)458 Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
459 const std::vector<string>& input_devices,
460 const std::vector<string>& output_devices, const DeviceSet& device_set,
461 const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
462 const FunctionLibraryDefinition* lib_def, Device* default_device) {
463 // If output_devices are not specified, we want to set the output device
464 // based on the device of the output producing node. The output producing
465 // node can be an arg node because functions can simply return their
466 // arguments. To make sure that the output producing nodes have assigned
467 // devices, we assign them to arguments first.
468 for (Node* node : arg_nodes) {
469 const AttrValue* attr_value;
470 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
471 int64_t index = attr_value->i();
472 node->set_assigned_device_name(input_devices[index]);
473 }
474
475 for (Node* node : ret_nodes) {
476 if (output_devices.empty()) {
477 DataType dtype;
478 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
479
480 VLOG(3) << "Trying to determine device for node " << node->name()
481 << "[T=" << DataTypeString(dtype) << "]";
482
483 // If output_devices are empty, the node producing retval
484 // must have explicitly assigned device or a colocation constraint
485 // to a node with explicitly assigned device.
486 for (const auto& it : node->in_edges()) {
487 if (it->IsControlEdge()) continue;
488
489 Node* src_node = it->src();
490 const string* src_device = AssignedOrRequestedDeviceName(*src_node);
491 string colocation_group = "";
492 GetColocationGroup(src_node, &colocation_group);
493 VLOG(3) << "Considering src: " << src_node->name()
494 << " src_device: " << *src_device
495 << " colo group: " << colocation_group;
496 while (src_device->empty() && colocation_group.empty() &&
497 src_node->IsIdentity()) {
498 // Only follows the real data input of Identity, not control edges.
499 Node* input_node;
500 TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
501 src_node = input_node;
502
503 src_device = AssignedOrRequestedDeviceName(*src_node);
504 GetColocationGroup(src_node, &colocation_group);
505 VLOG(3) << "Considering src: " << src_node->name()
506 << " src_device: " << *src_device
507 << " colo group: " << colocation_group;
508 }
509
510 // If resource is produced by a function call node, we can't trust
511 // source node device assignment, because multi-device functions can
512 // return resource placed on multiple devices. In such case we leave
513 // retval device assignment empty, and rely on placer to infer correct
514 // assignment based on actual output device.
515 const bool can_use_src_node_device =
516 !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def, *src_node));
517
518 if (!colocation_group.empty()) {
519 AttrValue::ListValue colo_attr;
520 colo_attr.add_s(colocation_group);
521 std::vector<string> colo_slice = {colocation_group};
522 node->AddAttr(kColocationAttrName, colo_slice);
523 } else if (!src_device->empty() && can_use_src_node_device) {
524 // Do not copy device from src node for variants, unless it is a no-op
525 // forward from input to output. This gets handled in
526 // colocation_graph.cc which has special logic for correctly placing
527 // _Retvals for various variant types.
528 if (dtype == DT_VARIANT && !src_node->IsArg()) {
529 continue;
530 }
531 // src_device can be a partially specified device. Find the
532 // matching device in the device_set.
533 DeviceNameUtils::ParsedName parsed;
534 if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
535 return errors::InvalidArgument(
536 "Failed to parse explicit device specification ", *src_device);
537 }
538 std::vector<Device*> matching_devices;
539 device_set.FindMatchingDevices(parsed, &matching_devices);
540 if (matching_devices.empty()) {
541 if (default_device != nullptr) {
542 matching_devices.push_back(default_device);
543 } else {
544 return errors::InvalidArgument(
545 "Unable to find any devices for spec ", *src_device);
546 }
547 } else if (matching_devices.size() != 1) {
548 bool on_same_task = true;
549 for (int i = 1; i < matching_devices.size(); ++i) {
550 if (!DeviceNameUtils::IsSameAddressSpace(
551 matching_devices.at(0)->parsed_name(),
552 matching_devices.at(i)->parsed_name())) {
553 on_same_task = false;
554 break;
555 }
556 }
557 // If the src node of an output is assigned to a address space (e.g.
558 // py_func), rely on placer to assign a device to the output.
559 if (on_same_task) {
560 continue;
561 }
562 // Compare with default_device if it has a narrower scope matching
563 // requested device.
564 if (default_device != nullptr) {
565 int colocated_on_default_device = 0;
566 for (int i = 0; i < matching_devices.size(); ++i) {
567 if (DeviceNameUtils::IsSameAddressSpace(
568 default_device->parsed_name(),
569 matching_devices.at(i)->parsed_name())) {
570 colocated_on_default_device++;
571 }
572 }
573 // Continue to raise error if multiple colocated devices are
574 // found.
575 if (colocated_on_default_device == 1) {
576 continue;
577 }
578 }
579 // Convert a vector of devices to a string.
580 // Using absl::StrJoin did not work in Android builds.
581 string devices = "[";
582 for (Device* device : matching_devices) {
583 devices.append(device->name());
584 devices.append(", ");
585 }
586 if (devices.size() > 2) {
587 devices.resize(devices.size() - 2);
588 }
589 devices.append("]");
590
591 return errors::InvalidArgument(
592 *src_device,
593 "When FunctionLibraryRuntime::Options.output_devices are "
594 "not specified for a multi-device function, the device "
595 "specification on the output node must match exactly one "
596 "device. Matched devices are ",
597 devices);
598 }
599 VLOG(3) << "Setting output device to " << matching_devices[0]->name()
600 << " for node " << SummarizeNode(*node);
601 node->set_assigned_device_name(matching_devices[0]->name());
602 } else if (!src_device->empty() && !can_use_src_node_device) {
603 VLOG(3) << "Did not set device for a resource output node "
604 << SummarizeNode(*node);
605 }
606 }
607 } else {
608 const AttrValue* attr_value;
609 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
610 int64_t index = attr_value->i();
611 // output_devices size is checked in InstantiateMultiDevice
612 DCHECK_GT(output_devices.size(), index);
613 VLOG(3) << "Setting output device to " << output_devices[index]
614 << " for return at index " << index;
615 node->set_assigned_device_name(output_devices[index]);
616 }
617 }
618 return OkStatus();
619 }
620
621 namespace {
622
ValidateNoListArguments(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const char * arg_type,const string & function_name)623 Status ValidateNoListArguments(
624 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
625 const string& function_name) {
626 for (const OpDef::ArgDef& arg : args) {
627 if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
628 return errors::InvalidArgument(
629 "Function ", function_name, " has an ", arg_type, " named \"",
630 arg.name(),
631 "\" that is a list of tensors."
632 " Multi-device functions support only single-tensor inputs "
633 " and outputs");
634 }
635 }
636 return OkStatus();
637 }
638
ValidateMultiDeviceOptions(const FunctionDef & fdef,const FunctionLibraryRuntime::InstantiateOptions & options)639 Status ValidateMultiDeviceOptions(
640 const FunctionDef& fdef,
641 const FunctionLibraryRuntime::InstantiateOptions& options) {
642 const OpDef& signature = fdef.signature();
643 // Multi-device functions currently do not support list inputs or outputs.
644 TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
645 signature.name()));
646 TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
647 signature.name()));
648 if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
649 fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
650 return errors::Unimplemented(
651 "Function '", signature.name(), "' has `",
652 FunctionLibraryDefinition::kIntsOnDeviceAttr,
653 "` attribute set. This attribute is not currently supported by "
654 "multi-device functions.");
655 }
656 if (options.input_devices.size() != signature.input_arg_size()) {
657 return errors::InvalidArgument(
658 "InstantiateOptions.input_devices must have the same length "
659 "as the number of arguments: input_devices length = ",
660 options.input_devices.size(),
661 " number of arguments = ", signature.input_arg_size());
662 }
663 if (!options.output_devices.empty() &&
664 options.output_devices.size() != signature.output_arg_size()) {
665 return errors::InvalidArgument(
666 "InstantiateOptions.output_devices must either be empty or have the "
667 "same length as the number of arguments: output_devices length = ",
668 options.output_devices.size(),
669 " number of arguments = ", signature.output_arg_size());
670 }
671 return OkStatus();
672 }
673
674 } // anonymous namespace
675
676 ProcessFunctionLibraryRuntime::AsyncAttributes::Summary
Summarize(const Graph * graph)677 ProcessFunctionLibraryRuntime::AsyncAttributes::Summarize(const Graph* graph) {
678 bool has_send_op = false;
679 bool has_recv_op = false;
680 bool has_unsafe_op = false;
681 for (const Node* node : graph->nodes()) {
682 if (node->IsSend() || node->IsHostSend()) {
683 has_send_op = true;
684 }
685 if (node->IsRecv() || node->IsHostRecv()) {
686 has_recv_op = true;
687 }
688 if (!ValidateOpIsSafeForSyncExecution(*node,
689 allow_control_flow_sync_execution())
690 .ok()) {
691 has_unsafe_op = true;
692 }
693 }
694 // (1) Anything completely unsupported?
695 if (has_unsafe_op) {
696 metrics::IncrementTestCounter("subgraph_async_summary", "unsafe_op");
697 return AsyncAttributes::kAsyncRequired;
698 }
699 // (2) That only leaves send/recv. If neither, then it's safe.
700 if (!has_send_op && !has_recv_op) {
701 metrics::IncrementTestCounter("subgraph_async_summary", "safe_for_sync");
702 return AsyncAttributes::kSafeForSync;
703 }
704 // (3) If each subgraph has only send or only recv, then it's possible to
705 // order them to run sequentially without deadlock.
706 if (has_send_op && !has_recv_op) {
707 metrics::IncrementTestCounter("subgraph_async_summary", "send_only");
708 return AsyncAttributes::kSendOnly;
709 }
710 if (has_recv_op && !has_send_op) {
711 metrics::IncrementTestCounter("subgraph_async_summary", "recv_only");
712 return AsyncAttributes::kRecvOnly;
713 }
714 // Otherwise, assume it's unsupported.
715 metrics::IncrementTestCounter("subgraph_async_summary", "other");
716 return AsyncAttributes::kAsyncRequired;
717 }
718
GetGraphAndArgRets(const string & function_name,AttrSlice attrs,const FunctionDef * fdef,const FunctionLibraryDefinition * lib_def,std::unique_ptr<Graph> * graph,std::vector<Node * > * arg_nodes,std::vector<Node * > * ret_nodes,std::vector<string> * ret_node_names,DataTypeVector * ret_types,std::vector<string> * control_ret_node_names)719 Status GetGraphAndArgRets(
720 const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
721 const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
722 std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
723 std::vector<string>* ret_node_names, DataTypeVector* ret_types,
724 std::vector<string>* control_ret_node_names) {
725 std::unique_ptr<FunctionBody> fbody;
726 // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
727 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
728 if (!fbody) {
729 LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
730 return errors::Internal("Failed to construct FunctionBody for ",
731 function_name);
732 }
733 *graph = std::unique_ptr<Graph>(fbody->graph);
734 arg_nodes->reserve(fbody->arg_nodes.size());
735 std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(),
736 std::back_inserter(*arg_nodes));
737 ret_nodes->reserve(fbody->ret_nodes.size());
738 std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(),
739 std::back_inserter(*ret_nodes));
740 fbody->graph = nullptr;
741 ret_node_names->reserve(fbody->ret_nodes.size());
742 for (const Node* node : fbody->ret_nodes) {
743 ret_node_names->push_back(node->name());
744 }
745 for (const auto& ret_type : fbody->ret_types) {
746 ret_types->push_back(ret_type);
747 }
748 control_ret_node_names->reserve(fbody->control_ret_nodes.size());
749 for (const Node* node : fbody->control_ret_nodes) {
750 control_ret_node_names->push_back(node->name());
751 }
752 return OkStatus();
753 }
754
InstantiateMultiDevice(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle)755 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
756 const string& function_name, AttrSlice attrs,
757 const FunctionLibraryRuntime::InstantiateOptions& options,
758 FunctionLibraryRuntime::Handle* handle) {
759 // Check if this function has already been instantiated.
760 const string& function_key = Canonicalize(function_name, attrs, options);
761
762 {
763 mutex_lock l(mu_);
764 const auto& it = table_.find(function_key);
765 if (it != table_.end()) {
766 *handle = it->second;
767 ++mdevice_data_[*handle]->instantiation_counter_;
768 return OkStatus();
769 }
770 }
771
772 VLOG(1) << "Instantiating MultiDevice function \"" << function_name
773 << "\" on default device \"" << options.target << "\"";
774 if (VLOG_IS_ON(3)) {
775 int index = 0;
776 VLOG(3) << "Requested input devices:";
777 for (const string& device : options.input_devices) {
778 VLOG(3) << " [input " << index++ << "] " << device;
779 }
780 index = 0;
781 VLOG(3) << "Requested output devices:";
782 for (const string& device : options.output_devices) {
783 VLOG(3) << " [output " << index++ << "] " << device;
784 }
785 }
786
787 const FunctionLibraryDefinition* lib_def =
788 options.lib_def == nullptr ? lib_def_ : options.lib_def;
789
790 const FunctionDef* fdef = lib_def->Find(function_name);
791 if (fdef == nullptr) {
792 return errors::InvalidArgument("Failed to find function \"", function_name,
793 "\" in function library: ", lib_def);
794 }
795
796 TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));
797
798 std::unique_ptr<Graph> graph;
799 std::vector<Node*> arg_nodes, ret_nodes;
800 std::vector<string> ret_node_names;
801 DataTypeVector ret_types;
802 std::vector<string> control_ret_node_names;
803
804 TF_RETURN_IF_ERROR(GetGraphAndArgRets(
805 function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
806 &ret_node_names, &ret_types, &control_ret_node_names));
807
808 GraphDef graph_def;
809 graph->ToGraphDef(&graph_def);
810 FunctionLibraryDefinition reachable_lib_def =
811 lib_def->ReachableDefinitions(graph_def);
812 *graph_def.mutable_library() = reachable_lib_def.ToProto();
813 if (options.graph_collector != nullptr) {
814 options.graph_collector->CollectRawGraph(graph_def);
815 }
816
817 Device* default_device = nullptr;
818 if (options.default_device_to_target && !options.target.empty()) {
819 // Make the `target` device the default device if nothing else is hard
820 // coded. This allows the same function definition to be specialized to
821 // different devices depending on the `PartitionedCallOp` device.
822 FunctionLibraryRuntime* flr = GetFLR(options.target);
823 if (flr == nullptr) {
824 return errors::InvalidArgument(
825 "Cannot instantiate multi-device function with target device ",
826 options.target);
827 }
828 default_device = flr->device();
829 }
830
831 // Mark each node in the graph to be compiled by specified device.
832 if (!options.xla_compile_device_type.empty()) {
833 for (Node* node : graph->op_nodes()) {
834 node->AddAttr("_xla_compile_device_type",
835 options.xla_compile_device_type);
836 }
837 }
838
839 const std::shared_ptr<DeviceSet> dev_set = device_set();
840
841 TF_RETURN_IF_ERROR(
842 SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
843 TF_RETURN_IF_ERROR(PinArgsAndRets(
844 options.input_devices, options.output_devices, *dev_set, arg_nodes,
845 ret_nodes, lib_def_,
846 options.config_proto.allow_soft_placement() ? default_device : nullptr));
847
848 auto data = std::make_unique<MultiDeviceFunctionData>(
849 function_name, function_key, ret_node_names.size(),
850 std::move(reachable_lib_def), std::move(ret_types));
851
852 // The runtime shouldn't depend on duplication between the function library
853 // owned by the graph and the one owned by the runtime. To ensure this, for
854 // now we ensure that the graph function library is empty and the runtime
855 // library receives the query from LookUps on the graph function library.
856 graph->mutable_flib_def()->set_default_registry(&data->lib_def_);
857 graph->mutable_flib_def()->Clear();
858
859 // Do not run function/graph optimization passes for component functions,
860 // since they have already processed the main function.
861 const bool should_run_optimization_passes = !options.is_component_function;
862 if (!should_run_optimization_passes) {
863 VLOG(1) << "Skipping function/graph optimization passes when instantiating "
864 "component function "
865 << function_name;
866 }
867
868 // Mapping from a function body node name to the control output name.
869 std::unordered_map<string, string> node_name_to_control_ret;
870
871 bool control_rets_updated = false;
872 if (should_run_optimization_passes) {
873 TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
874 *dev_set, options.config_proto, &graph, &data->lib_def_,
875 &control_ret_node_names, &control_rets_updated));
876 }
877
878 if (control_rets_updated) {
879 // Function graph pass may have resulted in different nodes/node names for
880 // control rets.
881 for (const auto& control_ret : control_ret_node_names) {
882 node_name_to_control_ret.emplace(control_ret, control_ret);
883 }
884 } else {
885 for (const auto& control_ret : fdef->control_ret()) {
886 node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
887 }
888 }
889
890 GraphOptimizationPassOptions optimization_options;
891 // TODO(iga): Thread other relevant options from SessionOptions.
892 SessionOptions session_options;
893 session_options.env = env_;
894 session_options.config = options.config_proto;
895 optimization_options.session_options = &session_options;
896 optimization_options.graph = &graph;
897 optimization_options.flib_def = &data->lib_def_;
898 optimization_options.device_set = dev_set.get();
899 optimization_options.is_function_graph = true;
900 std::vector<CompositeDevice*> composite_devices;
901 {
902 tf_shared_lock l(mu_);
903 for (auto* d : composite_devices_) composite_devices.push_back(d);
904 }
905 optimization_options.composite_devices = &composite_devices;
906 optimization_options.default_function_device = default_device;
907 optimization_options.function_def = fdef;
908 optimization_options.shape_inference_on_tfe_dialect_import =
909 options.shape_inference_on_tfe_dialect_import;
910
911 DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
912 if (should_run_optimization_passes) {
913 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
914 OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
915 }
916
917 // TODO(b/124993244): Smartly merge options in nested defuns, and raise
918 // exceptions/warnings in case where nested function call options are ignored.
919 DumpGraph("Before calling Placer", graph.get());
920 Placer placer(graph.get(), function_name, optimization_options.flib_def,
921 dev_set.get(), default_device,
922 options.config_proto.allow_soft_placement(),
923 options.config_proto.log_device_placement());
924 TF_RETURN_IF_ERROR(placer.Run());
925
926 DumpGraph("Before running POST_PLACEMENT passes", graph.get());
927 if (should_run_optimization_passes) {
928 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
929 OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
930 }
931
932 Device* cpu_device;
933 TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));
934
935 if (options.optimize_graph_fn) {
936 DumpGraph("Before running graph optimization fn", graph.get());
937 Status status = options.optimize_graph_fn(
938 std::move(ret_node_names), std::move(control_ret_node_names),
939 &data->lib_def_, *dev_set, cpu_device, &graph);
940 if (!status.ok()) {
941 LOG(WARNING) << "Ignoring multi-device function optimization failure: "
942 << status.ToString();
943 }
944 DumpGraph("After optimization", graph.get());
945 }
946
947 DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
948 if (should_run_optimization_passes) {
949 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
950 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
951 }
952
953 // Expand the nodes assigned to a CompositeDevice before graph partition to
954 // avoid generating a subgraph on a virtual device for execution.
955 // This transformation should happen as late as possible, in order to run as
956 // more graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
957 // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) on a smaller graph as possible.
958 TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
959 options.composite_devices, graph.get()));
960
961 if (options.graph_collector != nullptr) {
962 GraphDef def;
963 graph->ToGraphDef(&def);
964 *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
965 options.graph_collector->CollectOptimizedGraph(def);
966 }
967
968 VLOG(4) << "Main function graph to be partitioned:";
969 VLOG(4) << DebugString(graph->ToGraphDefDebug());
970
971 std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
972 TF_RETURN_IF_ERROR(
973 PartitionFunctionGraph(*dev_set, std::move(graph), &subgraphs));
974
975 for (const auto& pair : subgraphs) {
976 DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
977 pair.first, ")"),
978 pair.second.get());
979 }
980 optimization_options.graph = nullptr;
981 optimization_options.device_set = nullptr;
982 optimization_options.partition_graphs = &subgraphs;
983 // Normally POST_PARTITIONING passes are run by distributed workers.
984 // Distributed workers are currently not supported in this code path, so we
985 // run the passes here.
986 if (should_run_optimization_passes) {
987 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
988 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
989 }
990 for (const auto& pair : subgraphs) {
991 const auto* optimized_subgraph = pair.second.get();
992 DumpGraph(
993 strings::StrCat("After all optimization passes (", pair.first, ")"),
994 optimized_subgraph);
995 if (VLOG_IS_ON(1)) {
996 DumpGraphDefToFile(
997 strings::StrCat("pflr_after_all_optimization_passes_",
998 reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
999 pair.first),
1000 optimized_subgraph->ToGraphDefDebug());
1001 }
1002 }
1003
1004 if (options.graph_collector != nullptr) {
1005 for (const auto& pair : subgraphs) {
1006 GraphDef def;
1007 pair.second->ToGraphDef(&def);
1008 *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
1009 options.graph_collector->CollectPartitionedGraph(def);
1010 }
1011 }
1012
1013 // We must preserve control returns in each of the function components,
1014 // otherwise after function inlining we might prune side-effectful nodes.
1015 const auto control_ret =
1016 [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
1017 const auto it = node_name_to_control_ret.find(n->name());
1018 return it != node_name_to_control_ret.end()
1019 ? absl::make_optional<string>(it->second)
1020 : absl::nullopt;
1021 };
1022
1023 int i = 0;
1024 // Generate a random function_name to avoid one function reuse the partition
1025 // function instantiated by another function.
1026 FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
1027 FunctionNameGenerator name_generator(
1028 data_lib_def, absl::StrCat(function_name, "_", random::New64()));
1029 auto num_subgraphs = subgraphs.size();
1030 gtl::InlinedVector<Status, 4> instantiate_status(num_subgraphs);
1031 BlockingCounter counter(static_cast<int>(num_subgraphs));
1032 auto runner = [this, num_subgraphs](std::function<void()> fn) {
1033 // NOTE: Only use thread pool to instantiate sub-function when there are
1034 // more than 8 sub-functions. We want to avoid cost of switching thread when
1035 // there are only a few sub-functions.
1036 if (default_thread_pool_ != nullptr && num_subgraphs > 8) {
1037 default_thread_pool_->Schedule(fn);
1038 } else {
1039 fn();
1040 }
1041 };
1042
1043 // Before instantiating component functions, determine synchronous execution.
1044 data->enable_sync_execution = false;
1045 if (options.allow_small_function_optimizations) {
1046 data->enable_sync_execution = true;
1047 for (const auto& pair : subgraphs) {
1048 ComponentFunctionData* comp_data = &data->glue_[pair.first];
1049 const Graph* subgraph = pair.second.get();
1050 comp_data->async_attributes =
1051 AsyncAttributes(subgraph, options.allow_control_flow_sync_execution);
1052 if (comp_data->async_attributes.summary() ==
1053 AsyncAttributes::kAsyncRequired) {
1054 data->enable_sync_execution = false;
1055 }
1056 }
1057 }
1058
1059 // Instantiate each component function (subgraph).
1060 for (const auto& pair : subgraphs) {
1061 Status* status = &instantiate_status[i];
1062 string unique_name = name_generator.GetName();
1063 ComponentFunctionData* comp_data = &data->glue_[pair.first];
1064 runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
1065 &control_ret, &options, status, &counter, &data] {
1066 const string& target = pair.first;
1067
1068 const string& device_type =
1069 dev_set->FindDeviceByName(target)->device_type();
1070 Graph* subgraph = pair.second.get();
1071
1072 bool ints_on_device =
1073 (device_type == "TPU" || device_type == "XLA_CPU" ||
1074 device_type == "XLA_GPU" || options.int_args_and_retvals_on_device);
1075 status->Update(UpdateArgAndRetvalMetadata(
1076 subgraph, &comp_data->arg_indices, &comp_data->ret_indices,
1077 &comp_data->arg_alloc_attrs, &comp_data->ret_alloc_attrs,
1078 ints_on_device));
1079 if (!status->ok()) {
1080 counter.DecrementCount();
1081 return;
1082 }
1083 FunctionDef shard;
1084 status->Update(
1085 GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
1086 if (!status->ok()) {
1087 counter.DecrementCount();
1088 return;
1089 }
1090 status->Update(data_lib_def->AddFunctionDef(shard));
1091 if (!status->ok()) {
1092 counter.DecrementCount();
1093 return;
1094 }
1095 FunctionLibraryRuntime::InstantiateOptions opts;
1096 opts.executor_type = options.executor_type;
1097 opts.target = target;
1098 opts.lib_def = data_lib_def;
1099 opts.create_kernels_eagerly = options.create_kernels_eagerly;
1100 opts.state_handle = options.state_handle;
1101 opts.allow_small_function_optimizations = data->enable_sync_execution;
1102 opts.allow_control_flow_sync_execution =
1103 options.allow_control_flow_sync_execution;
1104 AttrValue ints_on_device_attr;
1105 ints_on_device_attr.set_b(options.int_args_and_retvals_on_device);
1106 shard.mutable_attr()->insert(
1107 {FunctionLibraryDefinition::kIntsOnDeviceAttr, ints_on_device_attr});
1108 auto attrs = AttrSlice(&shard.attr());
1109 VLOG(1) << "Start instantiating component function " << unique_name
1110 << " on device " << target;
1111 VLOG(4) << DebugString(shard);
1112
1113 auto* component_handle = new FunctionLibraryRuntime::Handle;
1114 auto done = [this, status, unique_name, comp_data, component_handle,
1115 &data, &counter](const Status& s) {
1116 status->Update(s);
1117
1118 VLOG(1) << "Finished instantiating component function " << unique_name
1119 << " with handle " << *component_handle << " status: " << s;
1120 if (status->ok()) {
1121 {
1122 mutex_lock l(mu_);
1123 if (function_data_[*component_handle]->is_cross_process()) {
1124 data->is_cross_process_ = true;
1125 }
1126 }
1127 comp_data->handle = *component_handle;
1128 }
1129 delete component_handle;
1130 counter.DecrementCount();
1131 };
1132
1133 FunctionLibraryRuntime* flr = GetFLR(opts.target);
1134 if (flr != nullptr) {
1135 // Initialize local function synchronously.
1136 Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
1137 done(s);
1138 } else {
1139 opts.ret_indices = comp_data->ret_indices;
1140 // Initialize remote function asynchronously.
1141 InstantiateRemote(unique_name, attrs, opts, component_handle, done);
1142 }
1143 });
1144 i += 1;
1145 }
1146 counter.Wait();
1147 StatusGroup group;
1148 for (auto& status : instantiate_status) {
1149 group.Update(status);
1150 }
1151 TF_RETURN_IF_ERROR(group.as_summary_status());
1152
1153 *handle = AddMultiDeviceHandle(std::move(data), function_key);
1154 VLOG(2) << "Instantiated MultiDevice function \"" << function_name
1155 << "\" with handle " << *handle;
1156 return OkStatus();
1157 }
1158
GetOutputDevices(FunctionLibraryRuntime::Handle handle,std::vector<Device * > * output_devices) const1159 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
1160 FunctionLibraryRuntime::Handle handle,
1161 std::vector<Device*>* output_devices) const {
1162 MultiDeviceFunctionData* data = IsMultiDevice(handle);
1163 if (data == nullptr) {
1164 return errors::InvalidArgument(
1165 "Failed for find multi-device function handle ", handle);
1166 }
1167
1168 for (const auto& pair : data->glue_) {
1169 const ComponentFunctionData& comp_data = pair.second;
1170 DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
1171 if (comp_data.ret_indices.empty()) {
1172 continue;
1173 }
1174
1175 const string& target = pair.first;
1176 FunctionLibraryRuntime* target_flr = GetFLR(target);
1177 Device* target_device = nullptr;
1178 Device* host = nullptr;
1179 if (target_flr == nullptr) {
1180 if (!data->has_remote_outputs) {
1181 data->has_remote_outputs = true;
1182 }
1183 target_device = device_set()->FindDeviceByName(target);
1184 string remote_host;
1185 TF_RETURN_IF_ERROR(
1186 DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
1187 host = device_set()->FindDeviceByName(remote_host);
1188 } else {
1189 target_device = target_flr->device();
1190 }
1191 output_devices->resize(data->num_outputs_);
1192 for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
1193 int ret_index = comp_data.ret_indices[j];
1194 if (data->ret_types_[ret_index] == DT_RESOURCE) {
1195 (*output_devices)[ret_index] = target_device;
1196 } else {
1197 (*output_devices)[ret_index] =
1198 comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
1199 }
1200 }
1201 }
1202
1203 return OkStatus();
1204 }
1205
PrepareRunMultiDevice(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,const MultiDeviceFunctionData ** data) const1206 Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice(
1207 const FunctionLibraryRuntime::Options& opts,
1208 FunctionLibraryRuntime::Handle handle,
1209 const MultiDeviceFunctionData** data) const {
1210 if (opts.create_rendezvous) {
1211 // FLR->Run() is the default entry point. It checks for cancellation,
1212 // creates rendezvous, etc.
1213 // Letting create_rendezvous through will do the wrong thing - each
1214 // component function will get a separate rendezvous created by its FLR.
1215 return errors::Internal(
1216 "Cannot call ProcessFunctionLibraryRuntime::Run with "
1217 "create_rendezvous=true. Please run the function "
1218 "using FunctionLibraryRuntime::Run");
1219 }
1220
1221 *data = IsMultiDevice(handle);
1222 if (*data == nullptr) {
1223 return errors::NotFound("Multi-device function handle ", handle,
1224 "not found. Was the function instantiated?");
1225 }
1226
1227 // Check whether we have the right rendezvous.
1228 if (opts.rendezvous && (*data)->is_cross_process_ &&
1229 !opts.rendezvous->is_cross_process()) {
1230 return errors::InvalidArgument(
1231 "Running a cross process function ", (*data)->function_name_,
1232 " without an appropriate cross process Rendezvous.");
1233 }
1234
1235 return OkStatus();
1236 }
1237
GetOrderedSubgraphs(const MultiDeviceFunctionData * data) const1238 std::vector<string> ProcessFunctionLibraryRuntime::GetOrderedSubgraphs(
1239 const MultiDeviceFunctionData* data) const {
1240 std::vector<string> subgraph_keys;
1241 subgraph_keys.reserve(data->glue_.size());
1242 for (const auto& pair : data->glue_) {
1243 subgraph_keys.push_back(pair.first);
1244 }
1245 auto send_first_ordering = [&](const string& a, const string& b) {
1246 auto a_summary = data->glue_.at(a).async_attributes.summary();
1247 auto b_summary = data->glue_.at(b).async_attributes.summary();
1248 if (a_summary == b_summary) {
1249 return false;
1250 }
1251 if (a_summary == AsyncAttributes::kSendOnly) {
1252 return true;
1253 }
1254 return false;
1255 };
1256 std::sort(subgraph_keys.begin(), subgraph_keys.end(), send_first_ordering);
1257 return subgraph_keys;
1258 }
1259
RunMultiDeviceSync(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle outer_handle,std::vector<FunctionRet> * rets,std::function<Status (const ComponentFunctionData & comp_data,InternalArgs * args)> get_component_args) const1260 Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync(
1261 const FunctionLibraryRuntime::Options& opts,
1262 FunctionLibraryRuntime::Handle outer_handle, std::vector<FunctionRet>* rets,
1263 std::function<Status(const ComponentFunctionData& comp_data,
1264 InternalArgs* args)>
1265 get_component_args) const {
1266 const MultiDeviceFunctionData* data;
1267 Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data);
1268 if (!prepare_status.ok()) {
1269 return prepare_status;
1270 }
1271
1272 FunctionLibraryRuntime::Options opts_copy = opts;
1273
1274 // Sort the subgraphs topologically before execution to avoid deadlock:
1275 //
1276 // Because subgraphs will not execute in parallel here, dependencies between
1277 // subgraphs cannot be resolved automatically. In contrast, with multi-
1278 // threaded execution, we launch all subgraphs at once, asynchronously, and
1279 // allow any to block mid-execution while its dependencies are resolved.
1280 //
1281 // In this synchronous execution path, currently supported ops with inter-
1282 // subgraph dependencies are send and receive. As `_Send` and `_HostSend`
1283 // are non-blocking, we run subgraphs with those first, and those with
1284 // the blocking '_Recv' and '_HostRecv' ops will have their dependencies
1285 // resolved before execution.
1286 //
1287 // We assume that the partitioning has a valid deadlock-free ordering and the
1288 // safety of running synchronously has already been confirmed by this point.
1289 std::vector<string> subgraph_keys = GetOrderedSubgraphs(data);
1290
1291 for (const string& target : subgraph_keys) {
1292 const ComponentFunctionData& comp_data = data->glue_.at(target);
1293 FunctionLibraryRuntime::Handle comp_handle = comp_data.handle;
1294
1295 opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
1296 opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
1297
1298 InternalArgs comp_args;
1299 Status args_status = get_component_args(comp_data, &comp_args);
1300 if (!args_status.ok()) {
1301 VLOG(2) << "Failed to get component function arguments: " << args_status;
1302 return args_status;
1303 }
1304 rets->resize(data->num_outputs_);
1305
1306 VLOG(1) << "Running component function on device " << target << " from "
1307 << data->function_name_ << " with handle " << comp_handle;
1308 FunctionLibraryRuntime* flr = GetFLR(target);
1309 if (flr != nullptr) {
1310 opts_copy.remote_execution = false;
1311 // When target device has private thread pool, use the target device
1312 // runner
1313 thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
1314 opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();
1315 VLOG(4) << " with " << opts_copy.DebugString();
1316
1317 std::vector<Tensor> comp_tensor_rets;
1318 Status run_status =
1319 flr->RunSync(opts_copy, comp_handle, GetLocalArgs(comp_args.args),
1320 &comp_tensor_rets);
1321 if (!run_status.ok()) {
1322 VLOG(2) << "Component function execution failed: " << run_status;
1323 const string function_and_msg = strings::StrCat(
1324 errors::FormatFunctionForError(data->function_name_), " ",
1325 run_status.error_message());
1326 if (opts.rendezvous != nullptr) opts.rendezvous->StartAbort(run_status);
1327 return errors::CreateWithUpdatedMessage(run_status, function_and_msg);
1328 } else {
1329 VLOG(2) << "Component function execution succeeded.";
1330 for (int i = 0; i < comp_tensor_rets.size(); ++i) {
1331 (*rets)[comp_data.ret_indices[i]] = comp_tensor_rets[i];
1332 }
1333 }
1334 } else {
1335 // Fall back to DistributedFunctionLibraryRuntime for remote execution.
1336 opts_copy.remote_execution = true;
1337 VLOG(4) << " with " << opts_copy.DebugString();
1338
1339 std::vector<std::unique_ptr<CleanUpItem>> cleanup_items;
1340 Notification n;
1341 Status s;
1342 std::vector<FunctionRet> comp_rets;
1343 RunInternal(opts_copy, comp_handle, comp_args.args, &comp_rets,
1344 &cleanup_items, [&n, &s](const Status& status) {
1345 s.Update(status);
1346 n.Notify();
1347 });
1348 n.WaitForNotification();
1349 return s;
1350 }
1351 }
1352 return OkStatus();
1353 }
1354
RunMultiDeviceAsync(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle outer_handle,std::vector<FunctionRet> * rets,std::vector<std::unique_ptr<CleanUpItem>> * cleanup_items,FunctionLibraryRuntime::DoneCallback done,std::function<Status (const ComponentFunctionData & comp_data,InternalArgs * args)> get_component_args) const1355 void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync(
1356 const FunctionLibraryRuntime::Options& opts,
1357 FunctionLibraryRuntime::Handle outer_handle, std::vector<FunctionRet>* rets,
1358 std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
1359 FunctionLibraryRuntime::DoneCallback done,
1360 std::function<Status(const ComponentFunctionData& comp_data,
1361 InternalArgs* args)>
1362 get_component_args) const {
1363 const MultiDeviceFunctionData* data;
1364 Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data);
1365 if (!prepare_status.ok()) {
1366 done(prepare_status);
1367 return;
1368 }
1369
1370 // A locally created cancellation manager, used only when the caller does not
1371 // provide one in argument.
1372 std::shared_ptr<CancellationManager> local_cm;
1373 CancellationManager* cm = opts.cancellation_manager;
1374 if (cm == nullptr) {
1375 local_cm = std::make_shared<CancellationManager>();
1376 cm = local_cm.get();
1377 }
1378
1379 auto* refcounted_done = new ReffedStatusCallback(std::move(done));
1380 for (int i = 0; i < data->glue_.size(); ++i) {
1381 refcounted_done->Ref();
1382 }
1383
1384 FunctionLibraryRuntime::Options opts_copy = opts;
1385 for (const auto& pair : data->glue_) {
1386 const string& target = pair.first;
1387 const ComponentFunctionData& comp_data = pair.second;
1388 FunctionLibraryRuntime::Handle comp_handle = pair.second.handle;
1389
1390 opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
1391 opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
1392 opts_copy.cancellation_manager = cm;
1393
1394 InternalArgs comp_args;
1395 Status s = get_component_args(comp_data, &comp_args);
1396 if (!s.ok()) {
1397 VLOG(2) << "Failed to get component function arguments: " << s;
1398 refcounted_done->UpdateStatus(s);
1399 refcounted_done->Unref();
1400 cm->StartCancel();
1401 continue;
1402 }
1403 std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
1404 rets->resize(data->num_outputs_);
1405
1406 auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
1407 cm, local_cm, data, comp_handle,
1408 target](const Status& status) {
1409 if (!status.ok()) {
1410 VLOG(2) << "Component function execution on target " << target
1411 << " from " << data->function_name_ << " with handle "
1412 << comp_handle << " failed: " << status;
1413 const string function_and_msg = strings::StrCat(
1414 errors::FormatFunctionForError(data->function_name_), " ",
1415 status.error_message());
1416 refcounted_done->UpdateStatus(
1417 errors::CreateWithUpdatedMessage(status, function_and_msg));
1418 // Cancel the execution of other component functions.
1419 cm->StartCancel();
1420 } else {
1421 VLOG(2) << "Component function execution on target " << target
1422 << " from " << data->function_name_ << " with handle "
1423 << comp_handle << " succeeded.";
1424 for (int i = 0; i < comp_rets->size(); ++i) {
1425 (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
1426 }
1427 }
1428 delete comp_rets;
1429 // refcounted_done is thread-safe
1430 refcounted_done->Unref();
1431 };
1432
1433 FunctionLibraryRuntime* flr = GetFLR(target);
1434 if (flr != nullptr) {
1435 opts_copy.remote_execution = false;
1436 // When target device has private thread pool, use the target device
1437 // runner
1438 thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
1439 opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();
1440
1441 VLOG(1) << "Running component function on device " << target << " from "
1442 << data->function_name_ << " with handle " << comp_handle;
1443 VLOG(4) << " with " << opts_copy.DebugString();
1444
1445 std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
1446 flr->Run(
1447 opts_copy, comp_handle, GetLocalArgs(comp_args.args),
1448 comp_tensor_rets,
1449 TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
1450 std::move(component_fn_callback)));
1451 } else {
1452 opts_copy.remote_execution = true;
1453
1454 VLOG(1) << "Running component function on device " << target << " from "
1455 << data->function_name_ << " with handle " << comp_handle;
1456 VLOG(4) << " with " << opts_copy.DebugString();
1457
1458 RunInternal(opts_copy, comp_handle, comp_args.args, comp_rets,
1459 cleanup_items, std::move(component_fn_callback));
1460 }
1461 }
1462 refcounted_done->Unref();
1463 }
1464
Instantiate(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle)1465 Status ProcessFunctionLibraryRuntime::Instantiate(
1466 const string& function_name, AttrSlice attrs,
1467 const FunctionLibraryRuntime::InstantiateOptions& options,
1468 FunctionLibraryRuntime::Handle* handle) {
1469 if (options.is_multi_device_function) {
1470 return InstantiateMultiDevice(function_name, attrs, options, handle);
1471 }
1472
1473 *handle = kInvalidHandle;
1474 FunctionLibraryRuntime* flr = GetFLR(options.target);
1475 if (flr != nullptr) {
1476 return flr->Instantiate(function_name, attrs, options, handle);
1477 }
1478
1479 Status status;
1480 Notification notification;
1481 InstantiateRemote(function_name, attrs, options, handle,
1482 [&status, ¬ification](const Status& s) {
1483 status = s;
1484 notification.Notify();
1485 });
1486 notification.WaitForNotification();
1487 return status;
1488 }
1489
IsCrossProcess(FunctionLibraryRuntime::Handle handle,bool * is_cross_process) const1490 Status ProcessFunctionLibraryRuntime::IsCrossProcess(
1491 FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const {
1492 tf_shared_lock l(mu_);
1493 const auto& mdevice_it = mdevice_data_.find(handle);
1494 if (mdevice_it != mdevice_data_.end()) {
1495 *is_cross_process = mdevice_it->second->is_cross_process_;
1496 return OkStatus();
1497 }
1498 const auto& it = function_data_.find(handle);
1499 if (it != function_data_.end()) {
1500 *is_cross_process = it->second->is_cross_process();
1501 return OkStatus();
1502 }
1503 return errors::InvalidArgument("Handle ", handle, " not found.");
1504 }
1505
InstantiateRemote(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle,FunctionLibraryRuntime::DoneCallback done)1506 void ProcessFunctionLibraryRuntime::InstantiateRemote(
1507 const string& function_name, AttrSlice attrs,
1508 const FunctionLibraryRuntime::InstantiateOptions& options,
1509 FunctionLibraryRuntime::Handle* handle,
1510 FunctionLibraryRuntime::DoneCallback done) {
1511 if (parent_ == nullptr) {
1512 done(errors::Internal(
1513 "Currently don't support instantiating functions on device: ",
1514 options.target));
1515 return;
1516 }
1517 auto target = options.target;
1518 VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target;
1519 string function_key = Canonicalize(function_name, attrs, options);
1520 FunctionData* f;
1521 {
1522 mutex_lock l(mu_);
1523 FunctionLibraryRuntime::Handle h =
1524 gtl::FindWithDefault(table_, function_key, kInvalidHandle);
1525 if (h == kInvalidHandle || function_data_.count(h) == 0) {
1526 h = AddHandleLocked(function_key, target, kInvalidHandle);
1527 }
1528 f = function_data_[h].get();
1529 *handle = h;
1530 }
1531 f->DistributedInit(
1532 parent_, function_name,
1533 options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options,
1534 [this, function_name, target, handle, done](const Status& s) {
1535 VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
1536 << " on: " << target << " with handle: " << *handle
1537 << " (this: " << this << ")";
1538 done(s);
1539 });
1540 }
1541
RemoveHandle(FunctionLibraryRuntime::Handle handle)1542 Status ProcessFunctionLibraryRuntime::RemoveHandle(
1543 FunctionLibraryRuntime::Handle handle) {
1544 mutex_lock l(mu_);
1545 table_.erase(function_data_[handle]->function_key());
1546 function_data_.erase(handle);
1547 return OkStatus();
1548 }
1549
ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle)1550 Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
1551 FunctionLibraryRuntime::Handle handle) {
1552 std::unique_ptr<MultiDeviceFunctionData> mdata;
1553 {
1554 mutex_lock l(mu_);
1555 auto it = mdevice_data_.find(handle);
1556 --it->second->instantiation_counter_;
1557 if (it->second->instantiation_counter_ != 0) {
1558 return OkStatus();
1559 }
1560 mdata = std::move(it->second);
1561 table_.erase(mdata->function_key_);
1562 mdevice_data_.erase(it);
1563 }
1564
1565 // If we are here we are releasing the last instantiation of `handle`.
1566 // Release all component function handles.
1567 Status overall_status;
1568 for (const auto& it : mdata->glue_) {
1569 const string& device = it.first;
1570 FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
1571 FunctionLibraryRuntime* flr = GetFLR(device);
1572 if (flr == nullptr) {
1573 // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
1574 // parent is not null.
1575 if (parent_ != nullptr) {
1576 return errors::Unimplemented(
1577 "Releasing a multi-device component handle on a remote device is "
1578 "not yet implemented.");
1579 }
1580 return errors::InvalidArgument(
1581 "Failed to find FunctionLibraryRuntime for device ", device,
1582 " when releasing multi-device function handle ", handle);
1583 }
1584 Status status = flr->ReleaseHandle(flr_handle);
1585 if (!status.ok()) {
1586 overall_status = status;
1587 }
1588 }
1589
1590 return overall_status;
1591 }
1592
ReleaseHandle(FunctionLibraryRuntime::Handle handle)1593 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
1594 FunctionLibraryRuntime::Handle handle) {
1595 // Return directly if all function handles has already been released.
1596 if (flr_map_ == nullptr) return OkStatus();
1597
1598 if (IsMultiDevice(handle)) {
1599 return ReleaseMultiDeviceHandle(handle);
1600 }
1601
1602 FunctionLibraryRuntime* flr = nullptr;
1603 string target_device;
1604 {
1605 mutex_lock l(mu_);
1606 CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
1607 target_device = function_data_[handle]->target_device();
1608 }
1609 flr = GetFLR(target_device);
1610 if (flr != nullptr) {
1611 return flr->ReleaseHandle(handle);
1612 }
1613 return errors::InvalidArgument("Handle not found: ", handle);
1614 }
1615
CleanupCreatedRendezvous(const Rendezvous * created_rendezvous,const int64_t step_id) const1616 void ProcessFunctionLibraryRuntime::CleanupCreatedRendezvous(
1617 const Rendezvous* created_rendezvous, const int64_t step_id) const {
1618 if (created_rendezvous) {
1619 DCHECK(rendezvous_factory_);
1620 created_rendezvous->Unref();
1621 Status s = rendezvous_factory_.CleanUp(step_id);
1622 if (!s.ok()) {
1623 LOG(ERROR) << s;
1624 }
1625 }
1626 }
1627
1628 FunctionLibraryRuntime::DoneCallback
ApplyCleanUpToDoneCallback(std::vector<std::unique_ptr<CleanUpItem>> * items,FunctionLibraryRuntime::DoneCallback done,const int64_t step_id,const Rendezvous * created_rendezvous) const1629 ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
1630 std::vector<std::unique_ptr<CleanUpItem>>* items,
1631 FunctionLibraryRuntime::DoneCallback done, const int64_t step_id,
1632 const Rendezvous* created_rendezvous) const {
1633 return [this, items, done = std::move(done), step_id,
1634 created_rendezvous](const Status& status) {
1635 this->CleanupCreatedRendezvous(created_rendezvous, step_id);
1636 auto* local_status = new Status(status);
1637 CleanUp(items, [local_status, done](const Status& cleanup_status) {
1638 local_status->Update(cleanup_status);
1639 done(*local_status);
1640 delete local_status;
1641 });
1642 delete items;
1643 };
1644 }
1645
CreateRendezvous(FunctionLibraryRuntime::Options & opts,Rendezvous ** created_rendezvous) const1646 Status ProcessFunctionLibraryRuntime::CreateRendezvous(
1647 FunctionLibraryRuntime::Options& opts,
1648 Rendezvous** created_rendezvous) const {
1649 DCHECK(opts.rendezvous == nullptr);
1650 if (!rendezvous_factory_) {
1651 return errors::FailedPrecondition(
1652 "The caller does not provide a rendezvous and "
1653 "ProcessFunctionLibraryRuntime was created without a rendezvous "
1654 "factory.");
1655 }
1656 Status s = rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
1657 if (s.ok()) {
1658 opts.rendezvous = *created_rendezvous;
1659 opts.create_rendezvous = false;
1660 }
1661 return s;
1662 }
1663
GetComponentArgs(const gtl::ArraySlice<Tensor> args,const ProcessFunctionLibraryRuntime::ComponentFunctionData & comp_data,ProcessFunctionLibraryRuntime::InternalArgs * comp_args)1664 Status ProcessFunctionLibraryRuntime::GetComponentArgs(
1665 const gtl::ArraySlice<Tensor> args,
1666 const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data,
1667 ProcessFunctionLibraryRuntime::InternalArgs* comp_args) {
1668 // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
1669 for (const auto& it : comp_data.arg_indices) {
1670 if (it.index >= args.size()) {
1671 return errors::InvalidArgument("index ", it.index,
1672 " is out of range [0, ", args.size(), ")");
1673 }
1674 if (it.sub_index >= 0) {
1675 const Tensor& t = args[it.index];
1676 if (t.dtype() != DT_RESOURCE) {
1677 return errors::InvalidArgument("Got unexpected sub_index ",
1678 it.sub_index, " for argument ",
1679 it.index);
1680 }
1681 const auto& handles = t.flat<ResourceHandle>();
1682 if (it.sub_index >= handles.size()) {
1683 return errors::InvalidArgument("Sub_index ", it.sub_index,
1684 "is out of range [0,", handles.size(),
1685 ") for argument ", it.index);
1686 }
1687 comp_args->args.push_back(Tensor(handles(it.sub_index)));
1688 } else {
1689 comp_args->args.push_back(args[it.index]);
1690 }
1691 }
1692 return OkStatus();
1693 }
1694
1695 #if !defined(IS_MOBILE_PLATFORM)
GetComponentArgs(const FunctionArgsInterface & args,const ProcessFunctionLibraryRuntime::ComponentFunctionData & comp_data,ProcessFunctionLibraryRuntime::InternalArgs * comp_args)1696 Status ProcessFunctionLibraryRuntime::GetComponentArgs(
1697 const FunctionArgsInterface& args,
1698 const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data,
1699 ProcessFunctionLibraryRuntime::InternalArgs* comp_args) {
1700 for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
1701 const FunctionArgIndex index = comp_data.arg_indices.at(i);
1702 Tensor tensor;
1703 if (args.GetLocalArg(index, &tensor).ok()) {
1704 comp_args->args.push_back(std::move(tensor));
1705 } else {
1706 eager::RemoteTensorHandle remote_handle;
1707 TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
1708 comp_args->remote_args.emplace_back(
1709 std::make_unique<eager::RemoteTensorHandle>(
1710 std::move(remote_handle)));
1711 comp_args->args.push_back(comp_args->remote_args.back().get());
1712 }
1713 }
1714 return OkStatus();
1715 }
1716 #endif // IS_MOBILE_PLATFORM
1717
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done) const1718 void ProcessFunctionLibraryRuntime::Run(
1719 const FunctionLibraryRuntime::Options& opts,
1720 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
1721 std::vector<Tensor>* rets,
1722 FunctionLibraryRuntime::DoneCallback done) const {
1723 FunctionLibraryRuntime::Options new_opts = opts;
1724 Rendezvous* created_rendezvous = nullptr;
1725 if (!opts.rendezvous) {
1726 Status s = CreateRendezvous(new_opts, &created_rendezvous);
1727 if (!s.ok()) {
1728 done(s);
1729 return;
1730 }
1731 }
1732
1733 auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
1734 done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
1735 new_opts.step_id, created_rendezvous);
1736 std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
1737 done = [rets, function_rets, done = std::move(done)](const Status& s) {
1738 Status status = s;
1739 if (status.ok()) {
1740 status.Update(FunctionRetsToTensors(function_rets, rets));
1741 }
1742 delete function_rets;
1743 done(status);
1744 };
1745 bool multi_device = HasMultiDeviceHandle(handle);
1746 if (multi_device) {
1747 auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1748 InternalArgs* comp_args) -> Status {
1749 return GetComponentArgs(args, comp_data, comp_args);
1750 };
1751 return RunMultiDeviceAsync(new_opts, handle, function_rets, cleanup_items,
1752 std::move(done), std::move(get_component_args));
1753 }
1754 std::vector<FunctionArg> local_args;
1755 for (const auto& tensor : args) {
1756 local_args.push_back(tensor);
1757 }
1758 RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
1759 std::move(done));
1760 }
1761
1762 // This method handles the simple remote call case (not multi-device).
RunInternal(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<FunctionArg> args,std::vector<FunctionRet> * rets,std::vector<std::unique_ptr<CleanUpItem>> * cleanup_items,FunctionLibraryRuntime::DoneCallback done) const1763 void ProcessFunctionLibraryRuntime::RunInternal(
1764 const FunctionLibraryRuntime::Options& opts,
1765 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
1766 std::vector<FunctionRet>* rets,
1767 std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
1768 FunctionLibraryRuntime::DoneCallback done) const {
1769 FunctionLibraryRuntime* flr = nullptr;
1770 string target_device;
1771 FunctionLibraryRuntime::LocalHandle local_handle;
1772 {
1773 tf_shared_lock l(mu_);
1774 auto iter = function_data_.find(handle);
1775 if (iter == function_data_.end()) {
1776 done(errors::NotFound("Handle: ", handle, " not found."));
1777 return;
1778 }
1779 FunctionData* function_data = iter->second.get();
1780 target_device = function_data->target_device();
1781 local_handle = function_data->local_handle();
1782 }
1783
1784 if (!opts.remote_execution) {
1785 done(
1786 errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
1787 "only be called for multi-device functions or "
1788 "for remote execution."));
1789 return;
1790 }
1791
1792 flr = GetFLR(target_device);
1793 if (flr != nullptr) {
1794 auto rendezvous = opts.rendezvous;
1795 string source_device = opts.source_device;
1796 DeviceContext* device_context;
1797 Status s = GetDeviceContext(source_device, &device_context);
1798 if (!s.ok()) {
1799 done(s);
1800 return;
1801 }
1802 int64_t src_incarnation, target_incarnation;
1803 s = GetDeviceIncarnation(source_device, &src_incarnation);
1804 s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
1805 if (!s.ok()) {
1806 done(s);
1807 return;
1808 }
1809
1810 std::vector<Tensor> local_args = GetLocalArgs(args);
1811
1812 // Send the args over to the target device.
1813 s = SendTensors(source_device, target_device, "arg_", src_incarnation,
1814 local_args, device_context, opts.args_alloc_attrs,
1815 rendezvous);
1816 if (!s.ok()) {
1817 done(s);
1818 return;
1819 }
1820 const std::vector<AllocatorAttributes>& rets_alloc_attrs =
1821 opts.rets_alloc_attrs;
1822 std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
1823 flr->Run(opts, handle, local_args, remote_rets,
1824 [source_device, target_device, target_incarnation, rendezvous,
1825 device_context, rets_alloc_attrs, remote_rets, rets,
1826 done = std::move(done)](const Status& status) mutable {
1827 if (!status.ok()) {
1828 delete remote_rets;
1829 done(status);
1830 return;
1831 }
1832 int64_t num_returns = remote_rets->size();
1833 delete remote_rets;
1834 // Now receive the return values from the target.
1835 std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
1836 ReceiveTensorsAsync(target_device, source_device, "ret_",
1837 target_incarnation, num_returns,
1838 device_context, rets_alloc_attrs, rendezvous,
1839 recv_tensors,
1840 TensorsToFunctionRetsDoneCallback(
1841 rets, recv_tensors, std::move(done)));
1842 });
1843 return;
1844 }
1845 if (parent_ != nullptr) {
1846 auto cleanup_item = std::make_unique<CleanUpItem>();
1847 cleanup_item->device = target_device;
1848 cleanup_item->step_id = opts.step_id;
1849 cleanup_item->local_handle = local_handle;
1850 cleanup_items->emplace_back(std::move(cleanup_item));
1851 parent_->Run(opts, local_handle, args, rets, std::move(done));
1852 return;
1853 }
1854 done(errors::Internal("Could not find device"));
1855 }
1856
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,CallFrameInterface * frame,FunctionLibraryRuntime::DoneCallback done) const1857 void ProcessFunctionLibraryRuntime::Run(
1858 const FunctionLibraryRuntime::Options& opts,
1859 FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
1860 FunctionLibraryRuntime::DoneCallback done) const {
1861 std::vector<Tensor> args;
1862 args.reserve(frame->num_args());
1863 for (size_t i = 0; i < frame->num_args(); ++i) {
1864 const Tensor* arg;
1865 Status s = frame->GetArg(i, &arg);
1866 args.emplace_back(*arg);
1867 if (!s.ok()) {
1868 done(s);
1869 }
1870 }
1871 std::vector<Tensor>* rets = new std::vector<Tensor>;
1872 rets->reserve(frame->num_retvals());
1873
1874 Run(opts, handle, args, rets,
1875
1876 [frame, rets, done = std::move(done)](const Status& status) {
1877 std::unique_ptr<std::vector<Tensor>> rets_releaser(rets);
1878
1879 if (!status.ok()) {
1880 done(status);
1881 return;
1882 }
1883
1884 if (rets->size() != frame->num_retvals()) {
1885 done(errors::Internal(
1886 "Number of return values from function (", rets->size(),
1887 ") did not match expected number of return values (",
1888 frame->num_retvals(), ")."));
1889 return;
1890 }
1891
1892 for (size_t i = 0; i < frame->num_retvals(); ++i) {
1893 Status s = frame->SetRetval(i, (*rets)[i]);
1894 if (!s.ok()) {
1895 done(s);
1896 return;
1897 }
1898 }
1899 done(OkStatus());
1900 });
1901 }
1902
RunSync(const FunctionLibraryRuntime::Options & orig_opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets) const1903 Status ProcessFunctionLibraryRuntime::RunSync(
1904 const FunctionLibraryRuntime::Options& orig_opts,
1905 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
1906 std::vector<Tensor>* rets) const {
1907 MultiDeviceFunctionData* multi_device_data = IsMultiDevice(handle);
1908 if (multi_device_data && multi_device_data->enable_sync_execution) {
1909 metrics::IncrementTestCounter("pflr_runsync", "sync");
1910 FunctionLibraryRuntime::Options new_opts = orig_opts;
1911 Rendezvous* created_rendezvous = nullptr;
1912 if (!new_opts.rendezvous) {
1913 TF_RETURN_IF_ERROR(CreateRendezvous(new_opts, &created_rendezvous));
1914 }
1915
1916 std::vector<FunctionRet> function_rets;
1917 auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1918 InternalArgs* comp_args) {
1919 return GetComponentArgs(args, comp_data, comp_args);
1920 };
1921
1922 Status status = RunMultiDeviceSync(new_opts, handle, &function_rets,
1923 std::move(get_component_args));
1924 CleanupCreatedRendezvous(created_rendezvous, new_opts.step_id);
1925 status.Update(FunctionRetsToTensors(&function_rets, rets));
1926 return status;
1927 } else {
1928 // TODO(b/207484417): Either handle or avoid/delete this fallback path.
1929 metrics::IncrementTestCounter("pflr_runsync", "async");
1930 Notification n;
1931 Status s;
1932 Run(orig_opts, handle, args, rets, [&n, &s](const Status& status) {
1933 s.Update(status);
1934 n.Notify();
1935 });
1936 n.WaitForNotification();
1937 return s;
1938 }
1939 }
1940
RunSync(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,CallFrameInterface * frame) const1941 Status ProcessFunctionLibraryRuntime::RunSync(
1942 const FunctionLibraryRuntime::Options& opts,
1943 FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
1944 // TODO(b/207485199): Implement this as synchronous code.
1945 Notification n;
1946 Status s;
1947 Run(opts, handle, frame, [&n, &s](const Status& status) {
1948 s.Update(status);
1949 n.Notify();
1950 });
1951 n.WaitForNotification();
1952 return s;
1953 }
1954
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,const FunctionArgsInterface & args,std::vector<FunctionRet> * rets,FunctionLibraryRuntime::DoneCallback done) const1955 void ProcessFunctionLibraryRuntime::Run(
1956 const FunctionLibraryRuntime::Options& opts,
1957 FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
1958 std::vector<FunctionRet>* rets,
1959 FunctionLibraryRuntime::DoneCallback done) const {
1960 bool has_remote_outputs = false;
1961 const MultiDeviceFunctionData* data = IsMultiDevice(handle);
1962 if (data != nullptr) {
1963 has_remote_outputs = data->has_remote_outputs;
1964 }
1965 if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
1966 const std::vector<Tensor> local_inputs = args.GetLocalTensors();
1967 std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
1968 return Run(
1969 opts, handle, local_inputs, tensor_rets,
1970 TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done)));
1971 }
1972
1973 FunctionLibraryRuntime::Options new_opts = opts;
1974 Rendezvous* created_rendezvous = nullptr;
1975 if (!opts.rendezvous) {
1976 Status s = CreateRendezvous(new_opts, &created_rendezvous);
1977 if (!s.ok()) {
1978 done(s);
1979 return;
1980 }
1981 }
1982
1983 #if defined(IS_MOBILE_PLATFORM)
1984 done(errors::Unimplemented(
1985 "Remote inputs are not available on mobile devices."));
1986 return;
1987 #else // !IS_MOBILE_PLATFORM
1988 auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
1989 done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
1990 created_rendezvous);
1991
1992 auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1993 InternalArgs* comp_args) -> Status {
1994 return GetComponentArgs(args, comp_data, comp_args);
1995 };
1996 return RunMultiDeviceAsync(new_opts, handle, rets, cleanup_items,
1997 std::move(done), std::move(get_component_args));
1998 #endif // !IS_MOBILE_PLATFORM
1999 }
2000
CleanUp(std::vector<std::unique_ptr<CleanUpItem>> * items,FunctionLibraryRuntime::DoneCallback done) const2001 void ProcessFunctionLibraryRuntime::CleanUp(
2002 std::vector<std::unique_ptr<CleanUpItem>>* items,
2003 FunctionLibraryRuntime::DoneCallback done) const {
2004 auto* refcounted_done = new ReffedStatusCallback(std::move(done));
2005 for (auto& item : *items) {
2006 refcounted_done->Ref();
2007 auto* flr = GetFLR(item->device);
2008 if (flr != nullptr) {
2009 // TODO(fishx): cleanup state for local execution.
2010 refcounted_done->UpdateStatus(
2011 errors::Internal("Cleanup items shouldn't contain local item."));
2012 refcounted_done->Unref();
2013 } else if (parent_ != nullptr) {
2014 parent_->CleanUp(item->step_id, item->local_handle,
2015 [refcounted_done](const Status& status) {
2016 if (!status.ok()) {
2017 refcounted_done->UpdateStatus(status);
2018 }
2019 // refcounted_done is thread-safe
2020 refcounted_done->Unref();
2021 });
2022 } else {
2023 refcounted_done->UpdateStatus(
2024 errors::Internal("Could not find device in cleanup."));
2025 refcounted_done->Unref();
2026 }
2027 }
2028 refcounted_done->Unref();
2029 }
2030
Clone(Env * env,int graph_def_version,const OptimizerOptions & optimizer_options,std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,bool skip_flib_def) const2031 Status ProcessFunctionLibraryRuntime::Clone(
2032 Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
2033 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
2034 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
2035 bool skip_flib_def) const {
2036 if (skip_flib_def) {
2037 *out_lib_def = std::make_unique<FunctionLibraryDefinition>(
2038 lib_def_->default_registry(), FunctionDefLibrary{});
2039 } else {
2040 *out_lib_def = std::make_unique<FunctionLibraryDefinition>(*lib_def_);
2041 }
2042 *out_pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
2043 device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
2044 out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
2045 session_metadata_, rendezvous_factory_);
2046 {
2047 tf_shared_lock l(mu_);
2048 for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
2049 }
2050 return OkStatus();
2051 }
2052
2053 } // namespace tensorflow
2054