1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16 
17 #include <algorithm>
18 #include <iterator>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/common_runtime/device_set.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/common_runtime/optimization_registry.h"
29 #include "tensorflow/core/common_runtime/partitioning_utils.h"
30 #include "tensorflow/core/common_runtime/placer.h"
31 #include "tensorflow/core/common_runtime/process_util.h"
32 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
33 #include "tensorflow/core/common_runtime/rendezvous_util.h"
34 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
35 #include "tensorflow/core/common_runtime/single_threaded_executor.h"
36 #include "tensorflow/core/framework/cancellation.h"
37 #include "tensorflow/core/framework/function.h"
38 #include "tensorflow/core/framework/graph_to_functiondef.h"
39 #include "tensorflow/core/framework/metrics.h"
40 #include "tensorflow/core/framework/op_kernel.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/framework/types.pb.h"
44 #include "tensorflow/core/graph/graph.h"
45 #include "tensorflow/core/graph/graph_node_util.h"
46 #include "tensorflow/core/graph/graph_partition.h"
47 #include "tensorflow/core/lib/core/blocking_counter.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/gtl/cleanup.h"
50 #include "tensorflow/core/lib/gtl/inlined_vector.h"
51 #include "tensorflow/core/lib/gtl/map_util.h"
52 #include "tensorflow/core/lib/random/random.h"
53 #include "tensorflow/core/platform/notification.h"
54 #include "tensorflow/core/util/device_name_utils.h"
55 #include "tensorflow/core/util/dump_graph.h"
56 #include "tensorflow/core/util/ptr_util.h"
57 #include "tensorflow/core/util/reffed_status_callback.h"
58 #if !defined(IS_MOBILE_PLATFORM)
59 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
60 #endif  // IS_MOBILE_PLATFORM
61 
62 namespace tensorflow {
63 
64 const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
65 
66 void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
67     DistributedFunctionLibraryRuntime* parent, const string& function_name,
68     const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
69     const FunctionLibraryRuntime::InstantiateOptions& options,
70     FunctionLibraryRuntime::DoneCallback done) {
71   {
72     mutex_lock l(mu_);
73     is_cross_process_ = true;
74     if (init_started_) {
75       init_done_.WaitForNotification();
76       done(init_result_);
77       return;
78     }
79     init_started_ = true;
80   }
81   parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
82                       [this, done](const Status& s) {
83                         init_done_.Notify();
84                         done(s);
85                       });
86 }
87 
88 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
89     const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
90     int graph_def_version, const FunctionLibraryDefinition* lib_def,
91     const OptimizerOptions& optimizer_options,
92     thread::ThreadPool* default_thread_pool,
93     DistributedFunctionLibraryRuntime* parent,
94     const SessionMetadata* session_metadata,
95     Rendezvous::Factory rendezvous_factory)
96     : parent_(parent),
97       env_(env),
98       config_(config ? absl::make_optional(*config) : absl::nullopt),
99       device_mgr_(device_mgr),
100       lib_def_(lib_def),
101       default_thread_pool_(default_thread_pool),
102       flr_map_(new std::unordered_map<Device*,
103                                       std::unique_ptr<FunctionLibraryRuntime>>),
104       next_handle_(0),
105       session_metadata_(session_metadata),
106       rendezvous_factory_(std::move(rendezvous_factory)),
107       optimizer_options_(optimizer_options),
108       graph_def_version_(graph_def_version) {
109   if (device_mgr == nullptr) {
110     (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
111         nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
112         graph_def_version, lib_def_, default_thread_pool, optimizer_options,
113         session_metadata_, this);
114     return;
115   }
116   InitializeDeviceAndFlr();
117 }
118 
119 /* static */
120 Status ProcessFunctionLibraryRuntime::SendTensors(
121     const string& source_device, const string& target_device,
122     const string& key_prefix, int64_t src_incarnation,
123     gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
124     const std::vector<AllocatorAttributes>& alloc_attrs,
125     RendezvousInterface* rendezvous) {
126   std::vector<string> keys;
127   for (int i = 0; i < tensors_to_send.size(); ++i) {
128     string name = strings::StrCat(key_prefix, i);
129     string key = Rendezvous::CreateKey(source_device, src_incarnation,
130                                        target_device, name, FrameAndIter(0, 0));
131     keys.push_back(key);
132   }
133   TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
134       rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
135   return OkStatus();
136 }
137 
138 /* static */
139 void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
140     const string& source_device, const string& target_device,
141     const string& key_prefix, int64_t src_incarnation, int64_t num_tensors,
142     DeviceContext* device_context,
143     const std::vector<AllocatorAttributes>& alloc_attrs,
144     RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
145     StatusCallback done) {
146   std::vector<string> keys;
147   for (int64_t i = 0; i < num_tensors; ++i) {
148     string name = strings::StrCat(key_prefix, i);
149     string key = Rendezvous::CreateKey(source_device, src_incarnation,
150                                        target_device, name, FrameAndIter(0, 0));
151     keys.push_back(key);
152   }
153   RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
154                                  received_tensors, std::move(done));
155 }
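// Example (illustrative sketch; `t0`, `t1`, `ctx`, `alloc_attrs`,
// `src_incarnation`, `rendezvous`, `done`, and the device names below are
// assumed to exist in the caller): the two helpers above pair up purely
// through rendezvous keys, so sender and receiver must agree on `key_prefix`
// and `src_incarnation`.
//
//   std::vector<Tensor> to_send = {t0, t1};
//   TF_RETURN_IF_ERROR(ProcessFunctionLibraryRuntime::SendTensors(
//       "/job:src/replica:0/task:0/cpu:0", "/job:dst/replica:0/task:0/cpu:0",
//       /*key_prefix=*/"call_0:", src_incarnation, to_send, ctx, alloc_attrs,
//       rendezvous));
//   std::vector<Tensor> received;
//   ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
//       "/job:src/replica:0/task:0/cpu:0", "/job:dst/replica:0/task:0/cpu:0",
//       /*key_prefix=*/"call_0:", src_incarnation, /*num_tensors=*/2, ctx,
//       alloc_attrs, rendezvous, &received, std::move(done));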
156 
157 Status ProcessFunctionLibraryRuntime::GetRetTypes(
158     FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) {
159   FunctionLibraryRuntime* flr = nullptr;
160   {
161     tf_shared_lock l(mu_);
162     auto miter = mdevice_data_.find(h);
163     if (miter != mdevice_data_.end()) {
164       *ret_types = miter->second->ret_types_;
165       return OkStatus();
166     }
167     auto fiter = function_data_.find(h);
168     if (fiter != function_data_.end()) {
169       flr = GetFLR(fiter->second->target_device());
170     }
171   }
172   if (flr != nullptr) {
173     return flr->GetRetTypes(h, ret_types);
174   }
175   return errors::InvalidArgument("Handle ", h, " not found.");
176 }
177 
178 Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
179     const string& device_name, int64_t* incarnation) const {
180   FunctionLibraryRuntime* flr = GetFLR(device_name);
181   if (flr == nullptr) {
182     return errors::InvalidArgument("Device name: ", device_name, " not found.");
183   }
184   *incarnation = flr->device()->attributes().incarnation();
185   return OkStatus();
186 }
187 
188 Status ProcessFunctionLibraryRuntime::GetDeviceContext(
189     const string& device_name, DeviceContext** device_context) const {
190   *device_context = nullptr;
191   FunctionLibraryRuntime* flr = GetFLR(device_name);
192   if (flr == nullptr) {
193     return errors::InvalidArgument("Device name: ", device_name, " not found.");
194   }
195   Device* device = flr->device();
196   string device_type = device->parsed_name().type;
197   if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
198     // "TPU_SYSTEM" indicates that `device` is a CPU.
199     return OkStatus();
200   }
201 
202   if (device->IsRemoteCallAllowed()) {
203     auto* dev_info = flr->device()->tensorflow_accelerator_device_info();
204     if (dev_info) {
205       *device_context = dev_info->default_context;
206       return OkStatus();
207     }
208   }
209 
210   return errors::Internal("Device type: ", device_type,
211                           " is currently unsupported for remote ",
212                           "function executions");
213 }
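// Summary of the contract above (descriptive, for illustration): for CPU and
// TPU_SYSTEM devices *device_context stays nullptr, meaning the tensors live
// in host memory and need no special context; for devices that allow remote
// calls (e.g. a GPU) the default context from the device's accelerator
// device info is returned; any other device type is rejected as unsupported
// for remote function execution.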
214 
215 void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
216   // Reset device_set_ in one of the following two scenarios:
217   // 1) Both cluster-FLR and its remote_device_mgr are available: include local
218   //    devices (if any) from the local device_mgr_ as Device type, and include
219   //    remote devices from cluster's remote_device_mgr as RemoteDevice type.
220   // 2) Include local devices from the local device_mgr_.
221   // In both scenarios, no device is added more than once.
222   mutex_lock l(mu_);
223   device_set_ = std::make_shared<DeviceSet>();
224   if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
225     for (auto d : parent_->remote_device_mgr()->ListDevices()) {
226       Device* device = nullptr;
227       if (device_mgr_->LookupDevice(d->name(), &device) == OkStatus()) {
228         // If this device exists in device_mgr_, i.e., it is a local device,
229         // add the instance owned by device_mgr_.
230         device_set_->AddDevice(device);
231       } else {
232         device_set_->AddDevice(d);
233       }
234     }
235   } else {
236     for (auto d : device_mgr_->ListDevices()) {
237       device_set_->AddDevice(d);
238     }
239   }
240 
241   // Update flr_map_ by adding new devices
242   for (Device* d : device_mgr_->ListDevices()) {
243     if ((*flr_map_)[d] == nullptr) {
244       (*flr_map_)[d] = NewFunctionLibraryRuntime(
245           device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
246           graph_def_version_, lib_def_, default_thread_pool_,
247           optimizer_options_, session_metadata_, this);
248     }
249   }
250 }
251 
252 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
253     const string& device_name) const {
254   Device* device = nullptr;
255   if (device_name != kDefaultFLRDevice) {
256     if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
257       VLOG(4) << "Could not find device: " << device_name;
258       return nullptr;
259     }
260   }
261   const auto& iter = flr_map_->find(device);
262   if (iter == flr_map_->end()) {
263     VLOG(1) << "Could not find device: " << device_name
264             << " in the local process.";
265     return nullptr;
266   }
267   return iter->second.get();
268 }
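// Usage note (illustration): GetFLR("/job:a/replica:0/task:0/cpu:0") returns
// the FunctionLibraryRuntime created for that local device, while
// GetFLR(kDefaultFLRDevice), i.e. "null", returns the device-less runtime
// that only exists when this ProcessFunctionLibraryRuntime was constructed
// without a DeviceMgr. Any unknown device name yields nullptr.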
269 
270 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
271     const string& function_key, const string& device_name,
272     FunctionLibraryRuntime::LocalHandle local_handle) {
273   mutex_lock l(mu_);
274   return AddHandleLocked(function_key, device_name, local_handle);
275 }
276 
277 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
278     const string& function_key, const string& device_name,
279     FunctionLibraryRuntime::LocalHandle local_handle) {
280   auto h = next_handle_;
281   function_data_[h] =
282       std::make_unique<FunctionData>(device_name, local_handle, function_key);
283   table_[function_key] = h;
284   next_handle_++;
285   return h;
286 }
287 
288 FunctionLibraryRuntime::Handle
289 ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
290     std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
291   mutex_lock l(mu_);
292   auto h = next_handle_;
293   mdevice_data_[h] = std::move(data);
294   table_[function_key] = h;
295   next_handle_++;
296   return h;
297 }
298 
299 bool ProcessFunctionLibraryRuntime::HasMultiDeviceHandle(
300     FunctionLibraryRuntime::Handle handle) const {
301   bool multi_device;
302   {
303     tf_shared_lock l(mu_);
304     multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
305   }
306   return multi_device;
307 }
308 
309 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
310     const string& function_key) const {
311   tf_shared_lock l(mu_);
312   return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
313 }
314 
315 FunctionLibraryRuntime::LocalHandle
316 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
317     const string& device_name, FunctionLibraryRuntime::Handle handle,
318     bool include_multi_device) const {
319   tf_shared_lock l(mu_);
320 
321   auto miter = mdevice_data_.find(handle);
322   if (miter != mdevice_data_.end()) {
323     if (!include_multi_device) return kInvalidLocalHandle;
324 
325     const MultiDeviceFunctionData& data = *miter->second;
326     if (data.glue_.size() != 1) return kInvalidLocalHandle;
327 
328     const auto& pair = *data.glue_.begin();
329     const string& func_device_name = pair.first;
330     const ComponentFunctionData& component_data = pair.second;
331     if (func_device_name != device_name) return kInvalidLocalHandle;
332 
333     // Replace the given handle with the handle for the single component
334     // function.
335     handle = component_data.handle;
336   }
337 
338   auto iter = function_data_.find(handle);
339   if (iter == function_data_.end()) {
340     return kInvalidLocalHandle;
341   }
342   FunctionData* function_data = iter->second.get();
343   if (function_data->target_device() != device_name) {
344     return kInvalidLocalHandle;
345   }
346   return function_data->local_handle();
347 }
348 
349 string ProcessFunctionLibraryRuntime::GetDeviceName(
350     FunctionLibraryRuntime::Handle handle) const {
351   tf_shared_lock l(mu_);
352   auto iter = function_data_.find(handle);
353   CHECK(iter != function_data_.end());
354   FunctionData* function_data = iter->second.get();
355   return function_data->target_device();
356 }
357 
358 ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
359 ProcessFunctionLibraryRuntime::IsMultiDevice(
360     FunctionLibraryRuntime::Handle handle) const {
361   tf_shared_lock l(mu_);
362   const auto& it = mdevice_data_.find(handle);
363   if (it != mdevice_data_.end()) {
364     return it->second.get();
365   }
366   return nullptr;
367 }
368 
369 namespace {
370 // Sets `group` to the first colocation group specified in `node`. If no
371 // group is specified, does not touch `group`.
372 void GetColocationGroup(const Node* node, string* group) {
373   // We hoist the conversion from C-style string literal to string here,
374   // so that we can avoid the many repeated calls to strlen().
375   static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
376   const AttrValue* attr_value =
377       node->attrs().Find(kColocationAttrNameStringPiece);
378   if (attr_value != nullptr && attr_value->has_list() &&
379       attr_value->list().s_size() > 0) {
380     *group = attr_value->list().s(0);
381   }
382 }
383 
384 const string* AssignedOrRequestedDeviceName(const Node& node) {
385   if (node.has_assigned_device_name()) {
386     return &node.assigned_device_name();
387   }
388   return &node.requested_device();
389 }
390 
391 Status SetArgShape(const std::unordered_map<int, DtypeAndPartialTensorShape>&
392                        input_resource_dtypes_and_shapes,
393                    const std::vector<Node*>& arg_nodes) {
394   for (Node* n : arg_nodes) {
395     int index;
396     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
397     DataType dtype;
398     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
399     if (dtype == DT_RESOURCE) {
400       auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
401       if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
402         AttrValue dtype_attr_value;
403         dtype_attr_value.mutable_list()->add_type(
404             dtype_and_shape_iter->second.dtype);
405         n->AddAttr("_handle_dtypes", dtype_attr_value);
406         TensorShapeProto shape_proto;
407         dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
408         AttrValue shape_attr_value;
409         *shape_attr_value.mutable_list()->add_shape() = shape_proto;
410         n->AddAttr("_handle_shapes", shape_attr_value);
411       }
412     }
413   }
414   return OkStatus();
415 }
416 
417 // Returns the local tensors referred to by `args`.
418 std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
419   std::vector<Tensor> tensors;
420   for (const auto& arg : args) {
421     if (arg.index() == 0) {
422       tensors.push_back(absl::get<Tensor>(arg));
423     }
424   }
425   return tensors;
426 }
427 
428 // Returns a done callback that pushes the Tensors in `tensors` into `rets`.
429 FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
430     std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
431     FunctionLibraryRuntime::DoneCallback done) {
432   return [rets, tensors, done = std::move(done)](const Status& s) {
433     if (s.ok()) {
434       for (const auto& t : *tensors) {
435         rets->push_back(t);
436       }
437     }
438     delete tensors;
439     done(s);
440   };
441 }
442 
443 // Push Tensors in `function_rets` into `tensors`.
444 Status FunctionRetsToTensors(const std::vector<FunctionRet>* function_rets,
445                              std::vector<Tensor>* tensors) {
446   for (const auto& ret : *function_rets) {
447     if (ret.index() != 0) {
448       return errors::Internal(
449           "Expect a Tensor as a function output but got a TensorShape.");
450     }
451     tensors->push_back(absl::get<Tensor>(ret));
452   }
453   return OkStatus();
454 }
455 
456 }  // anonymous namespace
457 
458 Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
459     const std::vector<string>& input_devices,
460     const std::vector<string>& output_devices, const DeviceSet& device_set,
461     const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
462     const FunctionLibraryDefinition* lib_def, Device* default_device) {
463   // If output_devices are not specified, we want to set the output device
464   // based on the device of the output producing node. The output producing
465   // node can be an arg node because functions can simply return their
466   // arguments. To make sure that the output producing nodes have assigned
467   // devices, we assign devices to the arguments first.
468   for (Node* node : arg_nodes) {
469     const AttrValue* attr_value;
470     TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
471     int64_t index = attr_value->i();
472     node->set_assigned_device_name(input_devices[index]);
473   }
474 
475   for (Node* node : ret_nodes) {
476     if (output_devices.empty()) {
477       DataType dtype;
478       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
479 
480       VLOG(3) << "Trying to determine device for node " << node->name()
481               << "[T=" << DataTypeString(dtype) << "]";
482 
483       // If output_devices are empty, the node producing retval
484       // must have an explicitly assigned device or a colocation constraint
485       // to a node with an explicitly assigned device.
486       for (const auto& it : node->in_edges()) {
487         if (it->IsControlEdge()) continue;
488 
489         Node* src_node = it->src();
490         const string* src_device = AssignedOrRequestedDeviceName(*src_node);
491         string colocation_group = "";
492         GetColocationGroup(src_node, &colocation_group);
493         VLOG(3) << "Considering src: " << src_node->name()
494                 << " src_device: " << *src_device
495                 << " colo group: " << colocation_group;
496         while (src_device->empty() && colocation_group.empty() &&
497                src_node->IsIdentity()) {
498           // Only follows the real data input of Identity, not control edges.
499           Node* input_node;
500           TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
501           src_node = input_node;
502 
503           src_device = AssignedOrRequestedDeviceName(*src_node);
504           GetColocationGroup(src_node, &colocation_group);
505           VLOG(3) << "Considering src: " << src_node->name()
506                   << " src_device: " << *src_device
507                   << " colo group: " << colocation_group;
508         }
509 
510         // If the resource is produced by a function call node, we can't trust
511         // the source node's device assignment, because multi-device functions
512         // can return resources placed on multiple devices. In such a case we
513         // leave the retval device assignment empty and rely on the placer to
514         // infer the correct assignment based on the actual output device.
515         const bool can_use_src_node_device =
516             !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def, *src_node));
517 
518         if (!colocation_group.empty()) {
519           AttrValue::ListValue colo_attr;
520           colo_attr.add_s(colocation_group);
521           std::vector<string> colo_slice = {colocation_group};
522           node->AddAttr(kColocationAttrName, colo_slice);
523         } else if (!src_device->empty() && can_use_src_node_device) {
524           // Do not copy device from src node for variants, unless it is a no-op
525           // forward from input to output. This gets handled in
526           // colocation_graph.cc which has special logic for correctly placing
527           // _Retvals for various variant types.
528           if (dtype == DT_VARIANT && !src_node->IsArg()) {
529             continue;
530           }
531           // src_device can be a partially specified device. Find the
532           // matching device in the device_set.
533           DeviceNameUtils::ParsedName parsed;
534           if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
535             return errors::InvalidArgument(
536                 "Failed to parse explicit device specification ", *src_device);
537           }
538           std::vector<Device*> matching_devices;
539           device_set.FindMatchingDevices(parsed, &matching_devices);
540           if (matching_devices.empty()) {
541             if (default_device != nullptr) {
542               matching_devices.push_back(default_device);
543             } else {
544               return errors::InvalidArgument(
545                   "Unable to find any devices for spec ", *src_device);
546             }
547           } else if (matching_devices.size() != 1) {
548             bool on_same_task = true;
549             for (int i = 1; i < matching_devices.size(); ++i) {
550               if (!DeviceNameUtils::IsSameAddressSpace(
551                       matching_devices.at(0)->parsed_name(),
552                       matching_devices.at(i)->parsed_name())) {
553                 on_same_task = false;
554                 break;
555               }
556             }
557             // If the src node of an output is assigned to an address space (e.g.
558             // py_func), rely on placer to assign a device to the output.
559             if (on_same_task) {
560               continue;
561             }
562             // Compare with default_device if it has a narrower scope matching
563             // the requested device.
564             if (default_device != nullptr) {
565               int colocated_on_default_device = 0;
566               for (int i = 0; i < matching_devices.size(); ++i) {
567                 if (DeviceNameUtils::IsSameAddressSpace(
568                         default_device->parsed_name(),
569                         matching_devices.at(i)->parsed_name())) {
570                   colocated_on_default_device++;
571                 }
572               }
573             // Fall through to the error below unless exactly one matching
574             // device is colocated with the default device.
575               if (colocated_on_default_device == 1) {
576                 continue;
577               }
578             }
579             // Convert a vector of devices to a string.
580             // Using absl::StrJoin did not work in Android builds.
581             string devices = "[";
582             for (Device* device : matching_devices) {
583               devices.append(device->name());
584               devices.append(", ");
585             }
586             if (devices.size() > 2) {
587               devices.resize(devices.size() - 2);
588             }
589             devices.append("]");
590 
591             return errors::InvalidArgument(
592                 *src_device,
593                 "When FunctionLibraryRuntime::Options.output_devices are "
594                 "not specified for a multi-device function, the device "
595                 "specification on the output node must match exactly one "
596                 "device. Matched devices are ",
597                 devices);
598           }
599           VLOG(3) << "Setting output device to " << matching_devices[0]->name()
600                   << " for node " << SummarizeNode(*node);
601           node->set_assigned_device_name(matching_devices[0]->name());
602         } else if (!src_device->empty() && !can_use_src_node_device) {
603           VLOG(3) << "Did not set device for a resource output node "
604                   << SummarizeNode(*node);
605         }
606       }
607     } else {
608       const AttrValue* attr_value;
609       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
610       int64_t index = attr_value->i();
611       // output_devices size is checked in InstantiateMultiDevice
612       DCHECK_GT(output_devices.size(), index);
613       VLOG(3) << "Setting output device to " << output_devices[index]
614               << " for return at index " << index;
615       node->set_assigned_device_name(output_devices[index]);
616     }
617   }
618   return OkStatus();
619 }
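// Worked example (hypothetical graph, for illustration only): with empty
// output_devices, a retval fed by
//     _Retval <- Identity <- VarHandleOp(device="/device:GPU:0")
// has no device on the Identity node, so the loop above walks back through
// the Identity to the VarHandleOp, parses "/device:GPU:0", finds the single
// matching device in `device_set`, and pins the _Retval to it. If the spec
// matched several devices on different tasks and `default_device` does not
// resolve the ambiguity, an InvalidArgument error is returned instead.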
620 
621 namespace {
622 
623 Status ValidateNoListArguments(
624     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
625     const string& function_name) {
626   for (const OpDef::ArgDef& arg : args) {
627     if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
628       return errors::InvalidArgument(
629           "Function ", function_name, " has an ", arg_type, " named \"",
630           arg.name(),
631           "\" that is a list of tensors."
632           " Multi-device functions support only single-tensor inputs "
633           "and outputs");
634     }
635   }
636   return OkStatus();
637 }
638 
639 Status ValidateMultiDeviceOptions(
640     const FunctionDef& fdef,
641     const FunctionLibraryRuntime::InstantiateOptions& options) {
642   const OpDef& signature = fdef.signature();
643   // Multi-device functions currently do not support list inputs or outputs.
644   TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
645                                              signature.name()));
646   TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
647                                              signature.name()));
648   if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
649       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
650     return errors::Unimplemented(
651         "Function '", signature.name(), "' has `",
652         FunctionLibraryDefinition::kIntsOnDeviceAttr,
653         "` attribute set. This attribute is not currently supported by "
654         "multi-device functions.");
655   }
656   if (options.input_devices.size() != signature.input_arg_size()) {
657     return errors::InvalidArgument(
658         "InstantiateOptions.input_devices must have the same length "
659         "as the number of arguments: input_devices length = ",
660         options.input_devices.size(),
661         " number of arguments = ", signature.input_arg_size());
662   }
663   if (!options.output_devices.empty() &&
664       options.output_devices.size() != signature.output_arg_size()) {
665     return errors::InvalidArgument(
666         "InstantiateOptions.output_devices must either be empty or have the "
667         "same length as the number of arguments: output_devices length = ",
668         options.output_devices.size(),
669         " number of arguments = ", signature.output_arg_size());
670   }
671   return OkStatus();
672 }
673 
674 }  // anonymous namespace
675 
676 ProcessFunctionLibraryRuntime::AsyncAttributes::Summary
677 ProcessFunctionLibraryRuntime::AsyncAttributes::Summarize(const Graph* graph) {
678   bool has_send_op = false;
679   bool has_recv_op = false;
680   bool has_unsafe_op = false;
681   for (const Node* node : graph->nodes()) {
682     if (node->IsSend() || node->IsHostSend()) {
683       has_send_op = true;
684     }
685     if (node->IsRecv() || node->IsHostRecv()) {
686       has_recv_op = true;
687     }
688     if (!ValidateOpIsSafeForSyncExecution(*node,
689                                           allow_control_flow_sync_execution())
690              .ok()) {
691       has_unsafe_op = true;
692     }
693   }
694   // (1) Anything completely unsupported?
695   if (has_unsafe_op) {
696     metrics::IncrementTestCounter("subgraph_async_summary", "unsafe_op");
697     return AsyncAttributes::kAsyncRequired;
698   }
699   // (2) That only leaves send/recv.  If neither, then it's safe.
700   if (!has_send_op && !has_recv_op) {
701     metrics::IncrementTestCounter("subgraph_async_summary", "safe_for_sync");
702     return AsyncAttributes::kSafeForSync;
703   }
704   // (3) If each subgraph has only send or only recv, then it's possible to
705   // order them to run sequentially without deadlock.
706   if (has_send_op && !has_recv_op) {
707     metrics::IncrementTestCounter("subgraph_async_summary", "send_only");
708     return AsyncAttributes::kSendOnly;
709   }
710   if (has_recv_op && !has_send_op) {
711     metrics::IncrementTestCounter("subgraph_async_summary", "recv_only");
712     return AsyncAttributes::kRecvOnly;
713   }
714   // Otherwise, assume it's unsupported.
715   metrics::IncrementTestCounter("subgraph_async_summary", "other");
716   return AsyncAttributes::kAsyncRequired;
717 }
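// For illustration, the classification above maps a component subgraph to:
//   - kAsyncRequired if any op is unsafe for synchronous execution, or the
//     graph contains both send and recv ops;
//   - kSafeForSync   if it has neither send nor recv ops;
//   - kSendOnly      if it has send ops but no recv ops;
//   - kRecvOnly      if it has recv ops but no send ops.
// kSendOnly/kRecvOnly subgraphs can still run synchronously provided the
// send-only ones are scheduled before the recv-only ones (see
// GetOrderedSubgraphs below).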
718 
719 Status GetGraphAndArgRets(
720     const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
721     const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
722     std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
723     std::vector<string>* ret_node_names, DataTypeVector* ret_types,
724     std::vector<string>* control_ret_node_names) {
725   std::unique_ptr<FunctionBody> fbody;
726   // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
727   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
728   if (!fbody) {
729     LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
730     return errors::Internal("Failed to construct FunctionBody for ",
731                             function_name);
732   }
733   *graph = std::unique_ptr<Graph>(fbody->graph);
734   arg_nodes->reserve(fbody->arg_nodes.size());
735   std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(),
736             std::back_inserter(*arg_nodes));
737   ret_nodes->reserve(fbody->ret_nodes.size());
738   std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(),
739             std::back_inserter(*ret_nodes));
740   fbody->graph = nullptr;
741   ret_node_names->reserve(fbody->ret_nodes.size());
742   for (const Node* node : fbody->ret_nodes) {
743     ret_node_names->push_back(node->name());
744   }
745   for (const auto& ret_type : fbody->ret_types) {
746     ret_types->push_back(ret_type);
747   }
748   control_ret_node_names->reserve(fbody->control_ret_nodes.size());
749   for (const Node* node : fbody->control_ret_nodes) {
750     control_ret_node_names->push_back(node->name());
751   }
752   return OkStatus();
753 }
754 
755 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
756     const string& function_name, AttrSlice attrs,
757     const FunctionLibraryRuntime::InstantiateOptions& options,
758     FunctionLibraryRuntime::Handle* handle) {
759   // Check if this function has already been instantiated.
760   const string& function_key = Canonicalize(function_name, attrs, options);
761 
762   {
763     mutex_lock l(mu_);
764     const auto& it = table_.find(function_key);
765     if (it != table_.end()) {
766       *handle = it->second;
767       ++mdevice_data_[*handle]->instantiation_counter_;
768       return OkStatus();
769     }
770   }
771 
772   VLOG(1) << "Instantiating MultiDevice function \"" << function_name
773           << "\" on default device \"" << options.target << "\"";
774   if (VLOG_IS_ON(3)) {
775     int index = 0;
776     VLOG(3) << "Requested input devices:";
777     for (const string& device : options.input_devices) {
778       VLOG(3) << "    [input " << index++ << "] " << device;
779     }
780     index = 0;
781     VLOG(3) << "Requested output devices:";
782     for (const string& device : options.output_devices) {
783       VLOG(3) << "    [output " << index++ << "] " << device;
784     }
785   }
786 
787   const FunctionLibraryDefinition* lib_def =
788       options.lib_def == nullptr ? lib_def_ : options.lib_def;
789 
790   const FunctionDef* fdef = lib_def->Find(function_name);
791   if (fdef == nullptr) {
792     return errors::InvalidArgument("Failed to find function \"", function_name,
793                                    "\" in function library: ", lib_def);
794   }
795 
796   TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));
797 
798   std::unique_ptr<Graph> graph;
799   std::vector<Node*> arg_nodes, ret_nodes;
800   std::vector<string> ret_node_names;
801   DataTypeVector ret_types;
802   std::vector<string> control_ret_node_names;
803 
804   TF_RETURN_IF_ERROR(GetGraphAndArgRets(
805       function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
806       &ret_node_names, &ret_types, &control_ret_node_names));
807 
808   GraphDef graph_def;
809   graph->ToGraphDef(&graph_def);
810   FunctionLibraryDefinition reachable_lib_def =
811       lib_def->ReachableDefinitions(graph_def);
812   *graph_def.mutable_library() = reachable_lib_def.ToProto();
813   if (options.graph_collector != nullptr) {
814     options.graph_collector->CollectRawGraph(graph_def);
815   }
816 
817   Device* default_device = nullptr;
818   if (options.default_device_to_target && !options.target.empty()) {
819     // Make the `target` device the default device if nothing else is hard
820     // coded. This allows the same function definition to be specialized to
821     // different devices depending on the `PartitionedCallOp` device.
822     FunctionLibraryRuntime* flr = GetFLR(options.target);
823     if (flr == nullptr) {
824       return errors::InvalidArgument(
825           "Cannot instantiate multi-device function with target device ",
826           options.target);
827     }
828     default_device = flr->device();
829   }
830 
831   // Mark each node in the graph to be compiled for the specified device type.
832   if (!options.xla_compile_device_type.empty()) {
833     for (Node* node : graph->op_nodes()) {
834       node->AddAttr("_xla_compile_device_type",
835                     options.xla_compile_device_type);
836     }
837   }
838 
839   const std::shared_ptr<DeviceSet> dev_set = device_set();
840 
841   TF_RETURN_IF_ERROR(
842       SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
843   TF_RETURN_IF_ERROR(PinArgsAndRets(
844       options.input_devices, options.output_devices, *dev_set, arg_nodes,
845       ret_nodes, lib_def_,
846       options.config_proto.allow_soft_placement() ? default_device : nullptr));
847 
848   auto data = std::make_unique<MultiDeviceFunctionData>(
849       function_name, function_key, ret_node_names.size(),
850       std::move(reachable_lib_def), std::move(ret_types));
851 
852   // The runtime shouldn't depend on duplication between the function library
853   // owned by the graph and the one owned by the runtime. To ensure this, for
854   // now we ensure that the graph function library is empty and that lookups
855   // on the graph function library fall through to the runtime's library.
856   graph->mutable_flib_def()->set_default_registry(&data->lib_def_);
857   graph->mutable_flib_def()->Clear();
858 
859   // Do not run function/graph optimization passes for component functions,
860   // since those passes have already been run on the main function.
861   const bool should_run_optimization_passes = !options.is_component_function;
862   if (!should_run_optimization_passes) {
863     VLOG(1) << "Skipping function/graph optimization passes when instantiating "
864                "component function "
865             << function_name;
866   }
867 
868   // Mapping from a function body node name to the control output name.
869   std::unordered_map<string, string> node_name_to_control_ret;
870 
871   bool control_rets_updated = false;
872   if (should_run_optimization_passes) {
873     TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
874         *dev_set, options.config_proto, &graph, &data->lib_def_,
875         &control_ret_node_names, &control_rets_updated));
876   }
877 
878   if (control_rets_updated) {
879     // Function graph pass may have resulted in different nodes/node names for
880     // control rets.
881     for (const auto& control_ret : control_ret_node_names) {
882       node_name_to_control_ret.emplace(control_ret, control_ret);
883     }
884   } else {
885     for (const auto& control_ret : fdef->control_ret()) {
886       node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
887     }
888   }
889 
890   GraphOptimizationPassOptions optimization_options;
891   // TODO(iga): Thread other relevant options from SessionOptions.
892   SessionOptions session_options;
893   session_options.env = env_;
894   session_options.config = options.config_proto;
895   optimization_options.session_options = &session_options;
896   optimization_options.graph = &graph;
897   optimization_options.flib_def = &data->lib_def_;
898   optimization_options.device_set = dev_set.get();
899   optimization_options.is_function_graph = true;
900   std::vector<CompositeDevice*> composite_devices;
901   {
902     tf_shared_lock l(mu_);
903     for (auto* d : composite_devices_) composite_devices.push_back(d);
904   }
905   optimization_options.composite_devices = &composite_devices;
906   optimization_options.default_function_device = default_device;
907   optimization_options.function_def = fdef;
908   optimization_options.shape_inference_on_tfe_dialect_import =
909       options.shape_inference_on_tfe_dialect_import;
910 
911   DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
912   if (should_run_optimization_passes) {
913     TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
914         OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
915   }
916 
917   // TODO(b/124993244): Smartly merge options in nested defuns, and raise
918   // exceptions/warnings in cases where nested function call options are ignored.
919   DumpGraph("Before calling Placer", graph.get());
920   Placer placer(graph.get(), function_name, optimization_options.flib_def,
921                 dev_set.get(), default_device,
922                 options.config_proto.allow_soft_placement(),
923                 options.config_proto.log_device_placement());
924   TF_RETURN_IF_ERROR(placer.Run());
925 
926   DumpGraph("Before running POST_PLACEMENT passes", graph.get());
927   if (should_run_optimization_passes) {
928     TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
929         OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
930   }
931 
932   Device* cpu_device;
933   TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));
934 
935   if (options.optimize_graph_fn) {
936     DumpGraph("Before running graph optimization fn", graph.get());
937     Status status = options.optimize_graph_fn(
938         std::move(ret_node_names), std::move(control_ret_node_names),
939         &data->lib_def_, *dev_set, cpu_device, &graph);
940     if (!status.ok()) {
941       LOG(WARNING) << "Ignoring multi-device function optimization failure: "
942                    << status.ToString();
943     }
944     DumpGraph("After optimization", graph.get());
945   }
946 
947   DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
948   if (should_run_optimization_passes) {
949     TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
950         OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
951   }
952 
953   // Expand the nodes assigned to a CompositeDevice before graph partitioning
954   // to avoid generating a subgraph on a virtual device for execution.
955   // This transformation should happen as late as possible, in order to run as
956   // many graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
957   // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) as possible on the smaller graph.
958   TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
959       options.composite_devices, graph.get()));
960 
961   if (options.graph_collector != nullptr) {
962     GraphDef def;
963     graph->ToGraphDef(&def);
964     *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
965     options.graph_collector->CollectOptimizedGraph(def);
966   }
967 
968   VLOG(4) << "Main function graph to be partitioned:";
969   VLOG(4) << DebugString(graph->ToGraphDefDebug());
970 
971   std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
972   TF_RETURN_IF_ERROR(
973       PartitionFunctionGraph(*dev_set, std::move(graph), &subgraphs));
974 
975   for (const auto& pair : subgraphs) {
976     DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
977                               pair.first, ")"),
978               pair.second.get());
979   }
980   optimization_options.graph = nullptr;
981   optimization_options.device_set = nullptr;
982   optimization_options.partition_graphs = &subgraphs;
983   // Normally POST_PARTITIONING passes are run by distributed workers.
984   // Distributed workers are currently not supported in this code path, so we
985   // run the passes here.
986   if (should_run_optimization_passes) {
987     TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
988         OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
989   }
990   for (const auto& pair : subgraphs) {
991     const auto* optimized_subgraph = pair.second.get();
992     DumpGraph(
993         strings::StrCat("After all optimization passes (", pair.first, ")"),
994         optimized_subgraph);
995     if (VLOG_IS_ON(1)) {
996       DumpGraphDefToFile(
997           strings::StrCat("pflr_after_all_optimization_passes_",
998                           reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
999                           pair.first),
1000           optimized_subgraph->ToGraphDefDebug());
1001     }
1002   }
1003 
1004   if (options.graph_collector != nullptr) {
1005     for (const auto& pair : subgraphs) {
1006       GraphDef def;
1007       pair.second->ToGraphDef(&def);
1008       *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
1009       options.graph_collector->CollectPartitionedGraph(def);
1010     }
1011   }
1012 
1013   // We must preserve control returns in each of the function components,
1014   // otherwise after function inlining we might prune side-effectful nodes.
1015   const auto control_ret =
1016       [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
1017     const auto it = node_name_to_control_ret.find(n->name());
1018     return it != node_name_to_control_ret.end()
1019                ? absl::make_optional<string>(it->second)
1020                : absl::nullopt;
1021   };
1022 
1023   int i = 0;
1024   // Generate a random function_name to prevent one function from reusing the
1025   // partitioned function instantiated by another function.
1026   FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
1027   FunctionNameGenerator name_generator(
1028       data_lib_def, absl::StrCat(function_name, "_", random::New64()));
1029   auto num_subgraphs = subgraphs.size();
1030   gtl::InlinedVector<Status, 4> instantiate_status(num_subgraphs);
1031   BlockingCounter counter(static_cast<int>(num_subgraphs));
1032   auto runner = [this, num_subgraphs](std::function<void()> fn) {
1033     // NOTE: Only use the thread pool to instantiate sub-functions when there
1034     // are more than 8 sub-functions. We want to avoid the cost of thread
1035     // switching when there are only a few sub-functions.
1036     if (default_thread_pool_ != nullptr && num_subgraphs > 8) {
1037       default_thread_pool_->Schedule(fn);
1038     } else {
1039       fn();
1040     }
1041   };
1042 
1043   // Before instantiating component functions, determine whether they can be executed synchronously.
1044   data->enable_sync_execution = false;
1045   if (options.allow_small_function_optimizations) {
1046     data->enable_sync_execution = true;
1047     for (const auto& pair : subgraphs) {
1048       ComponentFunctionData* comp_data = &data->glue_[pair.first];
1049       const Graph* subgraph = pair.second.get();
1050       comp_data->async_attributes =
1051           AsyncAttributes(subgraph, options.allow_control_flow_sync_execution);
1052       if (comp_data->async_attributes.summary() ==
1053           AsyncAttributes::kAsyncRequired) {
1054         data->enable_sync_execution = false;
1055       }
1056     }
1057   }
1058 
1059   // Instantiate each component function (subgraph).
1060   for (const auto& pair : subgraphs) {
1061     Status* status = &instantiate_status[i];
1062     string unique_name = name_generator.GetName();
1063     ComponentFunctionData* comp_data = &data->glue_[pair.first];
1064     runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
1065             &control_ret, &options, status, &counter, &data] {
1066       const string& target = pair.first;
1067 
1068       const string& device_type =
1069           dev_set->FindDeviceByName(target)->device_type();
1070       Graph* subgraph = pair.second.get();
1071 
1072       bool ints_on_device =
1073           (device_type == "TPU" || device_type == "XLA_CPU" ||
1074            device_type == "XLA_GPU" || options.int_args_and_retvals_on_device);
1075       status->Update(UpdateArgAndRetvalMetadata(
1076           subgraph, &comp_data->arg_indices, &comp_data->ret_indices,
1077           &comp_data->arg_alloc_attrs, &comp_data->ret_alloc_attrs,
1078           ints_on_device));
1079       if (!status->ok()) {
1080         counter.DecrementCount();
1081         return;
1082       }
1083       FunctionDef shard;
1084       status->Update(
1085           GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
1086       if (!status->ok()) {
1087         counter.DecrementCount();
1088         return;
1089       }
1090       status->Update(data_lib_def->AddFunctionDef(shard));
1091       if (!status->ok()) {
1092         counter.DecrementCount();
1093         return;
1094       }
1095       FunctionLibraryRuntime::InstantiateOptions opts;
1096       opts.executor_type = options.executor_type;
1097       opts.target = target;
1098       opts.lib_def = data_lib_def;
1099       opts.create_kernels_eagerly = options.create_kernels_eagerly;
1100       opts.state_handle = options.state_handle;
1101       opts.allow_small_function_optimizations = data->enable_sync_execution;
1102       opts.allow_control_flow_sync_execution =
1103           options.allow_control_flow_sync_execution;
1104       AttrValue ints_on_device_attr;
1105       ints_on_device_attr.set_b(options.int_args_and_retvals_on_device);
1106       shard.mutable_attr()->insert(
1107           {FunctionLibraryDefinition::kIntsOnDeviceAttr, ints_on_device_attr});
1108       auto attrs = AttrSlice(&shard.attr());
1109       VLOG(1) << "Start instantiating component function " << unique_name
1110               << " on device " << target;
1111       VLOG(4) << DebugString(shard);
1112 
1113       auto* component_handle = new FunctionLibraryRuntime::Handle;
1114       auto done = [this, status, unique_name, comp_data, component_handle,
1115                    &data, &counter](const Status& s) {
1116         status->Update(s);
1117 
1118         VLOG(1) << "Finished instantiating component function " << unique_name
1119                 << " with handle " << *component_handle << " status: " << s;
1120         if (status->ok()) {
1121           {
1122             mutex_lock l(mu_);
1123             if (function_data_[*component_handle]->is_cross_process()) {
1124               data->is_cross_process_ = true;
1125             }
1126           }
1127           comp_data->handle = *component_handle;
1128         }
1129         delete component_handle;
1130         counter.DecrementCount();
1131       };
1132 
1133       FunctionLibraryRuntime* flr = GetFLR(opts.target);
1134       if (flr != nullptr) {
1135         // Initialize local function synchronously.
1136         Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
1137         done(s);
1138       } else {
1139         opts.ret_indices = comp_data->ret_indices;
1140         // Initialize remote function asynchronously.
1141         InstantiateRemote(unique_name, attrs, opts, component_handle, done);
1142       }
1143     });
1144     i += 1;
1145   }
1146   counter.Wait();
1147   StatusGroup group;
1148   for (auto& status : instantiate_status) {
1149     group.Update(status);
1150   }
1151   TF_RETURN_IF_ERROR(group.as_summary_status());
1152 
1153   *handle = AddMultiDeviceHandle(std::move(data), function_key);
1154   VLOG(2) << "Instantiated MultiDevice function \"" << function_name
1155           << "\" with handle " << *handle;
1156   return OkStatus();
1157 }
1158 
1159 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
1160     FunctionLibraryRuntime::Handle handle,
1161     std::vector<Device*>* output_devices) const {
1162   MultiDeviceFunctionData* data = IsMultiDevice(handle);
1163   if (data == nullptr) {
1164     return errors::InvalidArgument(
1165         "Failed to find multi-device function handle ", handle);
1166   }
1167 
1168   for (const auto& pair : data->glue_) {
1169     const ComponentFunctionData& comp_data = pair.second;
1170     DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
1171     if (comp_data.ret_indices.empty()) {
1172       continue;
1173     }
1174 
1175     const string& target = pair.first;
1176     FunctionLibraryRuntime* target_flr = GetFLR(target);
1177     Device* target_device = nullptr;
1178     Device* host = nullptr;
1179     if (target_flr == nullptr) {
1180       if (!data->has_remote_outputs) {
1181         data->has_remote_outputs = true;
1182       }
1183       target_device = device_set()->FindDeviceByName(target);
1184       string remote_host;
1185       TF_RETURN_IF_ERROR(
1186           DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
1187       host = device_set()->FindDeviceByName(remote_host);
1188     } else {
1189       target_device = target_flr->device();
1190     }
1191     output_devices->resize(data->num_outputs_);
1192     for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
1193       int ret_index = comp_data.ret_indices[j];
1194       if (data->ret_types_[ret_index] == DT_RESOURCE) {
1195         (*output_devices)[ret_index] = target_device;
1196       } else {
1197         (*output_devices)[ret_index] =
1198             comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
1199       }
1200     }
1201   }
1202 
1203   return OkStatus();
1204 }
1205 
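// Shared validation for the sync and async multi-device run paths: rejects
// opts.create_rendezvous (each component FLR would otherwise create its own
// rendezvous), looks up the MultiDeviceFunctionData for `handle`, and verifies
// that cross-process functions are run with a cross-process rendezvous.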
1206 Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice(
1207     const FunctionLibraryRuntime::Options& opts,
1208     FunctionLibraryRuntime::Handle handle,
1209     const MultiDeviceFunctionData** data) const {
1210   if (opts.create_rendezvous) {
1211     // FLR->Run() is the default entry point. It checks for cancellation,
1212     // creates rendezvous, etc.
1213     // Letting create_rendezvous through will do the wrong thing - each
1214     // component function will get a separate rendezvous created by its FLR.
1215     return errors::Internal(
1216         "Cannot call ProcessFunctionLibraryRuntime::Run with "
1217         "create_rendezvous=true. Please run the function "
1218         "using FunctionLibraryRuntime::Run");
1219   }
1220 
1221   *data = IsMultiDevice(handle);
1222   if (*data == nullptr) {
1223     return errors::NotFound("Multi-device function handle ", handle,
1224                             " not found. Was the function instantiated?");
1225   }
1226 
1227   // Check whether we have the right rendezvous.
1228   if (opts.rendezvous && (*data)->is_cross_process_ &&
1229       !opts.rendezvous->is_cross_process()) {
1230     return errors::InvalidArgument(
1231         "Running a cross process function ", (*data)->function_name_,
1232         " without an appropriate cross process Rendezvous.");
1233   }
1234 
1235   return OkStatus();
1236 }
1237 
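// Returns the component subgraph keys with send-only subgraphs ordered first,
// so the synchronous run path does not block on a `_Recv` whose matching
// `_Send` has not executed yet. Illustrative example (hypothetical device
// names): if the component on "/device:GPU:0" is summarized as kSendOnly and
// the one on "/device:CPU:0" is not, the returned order is {"/device:GPU:0",
// "/device:CPU:0"}.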
1238 std::vector<string> ProcessFunctionLibraryRuntime::GetOrderedSubgraphs(
1239     const MultiDeviceFunctionData* data) const {
1240   std::vector<string> subgraph_keys;
1241   subgraph_keys.reserve(data->glue_.size());
1242   for (const auto& pair : data->glue_) {
1243     subgraph_keys.push_back(pair.first);
1244   }
1245   auto send_first_ordering = [&](const string& a, const string& b) {
1246     auto a_summary = data->glue_.at(a).async_attributes.summary();
1247     auto b_summary = data->glue_.at(b).async_attributes.summary();
1248     if (a_summary == b_summary) {
1249       return false;
1250     }
1251     if (a_summary == AsyncAttributes::kSendOnly) {
1252       return true;
1253     }
1254     return false;
1255   };
1256   std::sort(subgraph_keys.begin(), subgraph_keys.end(), send_first_ordering);
1257   return subgraph_keys;
1258 }
1259 
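// Runs the component functions of a multi-device function one after another on
// the calling thread. Components with a local FLR run via RunSync; components
// without one fall back to the asynchronous RunInternal path and are waited on
// with a Notification.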
1260 Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync(
1261     const FunctionLibraryRuntime::Options& opts,
1262     FunctionLibraryRuntime::Handle outer_handle, std::vector<FunctionRet>* rets,
1263     std::function<Status(const ComponentFunctionData& comp_data,
1264                          InternalArgs* args)>
1265         get_component_args) const {
1266   const MultiDeviceFunctionData* data;
1267   Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data);
1268   if (!prepare_status.ok()) {
1269     return prepare_status;
1270   }
1271 
1272   FunctionLibraryRuntime::Options opts_copy = opts;
1273 
1274   // Sort the subgraphs topologically before execution to avoid deadlock:
1275   //
1276   // Because subgraphs will not execute in parallel here, dependencies between
1277   // subgraphs cannot be resolved automatically. In contrast, with multi-
1278   // threaded execution, we launch all subgraphs at once, asynchronously, and
1279   // allow any to block mid-execution while its dependencies are resolved.
1280   //
1281   // In this synchronous execution path, currently supported ops with inter-
1282   // subgraph dependencies are send and receive.  As `_Send` and `_HostSend`
1283   // are non-blocking, we run subgraphs with those first, and those with
1284   // the blocking `_Recv` and `_HostRecv` ops will have their dependencies
1285   // resolved before execution.
1286   //
1287   // We assume that the partitioning has a valid deadlock-free ordering and the
1288   // safety of running synchronously has already been confirmed by this point.
1289   std::vector<string> subgraph_keys = GetOrderedSubgraphs(data);
1290 
1291   for (const string& target : subgraph_keys) {
1292     const ComponentFunctionData& comp_data = data->glue_.at(target);
1293     FunctionLibraryRuntime::Handle comp_handle = comp_data.handle;
1294 
1295     opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
1296     opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
1297 
1298     InternalArgs comp_args;
1299     Status args_status = get_component_args(comp_data, &comp_args);
1300     if (!args_status.ok()) {
1301       VLOG(2) << "Failed to get component function arguments: " << args_status;
1302       return args_status;
1303     }
1304     rets->resize(data->num_outputs_);
1305 
1306     VLOG(1) << "Running component function on device " << target << " from "
1307             << data->function_name_ << " with handle " << comp_handle;
1308     FunctionLibraryRuntime* flr = GetFLR(target);
1309     if (flr != nullptr) {
1310       opts_copy.remote_execution = false;
1311       // When the target device has a private thread pool, use the target
1312       // device's runner.
1313       thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
1314       opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();
1315       VLOG(4) << "    with " << opts_copy.DebugString();
1316 
1317       std::vector<Tensor> comp_tensor_rets;
1318       Status run_status =
1319           flr->RunSync(opts_copy, comp_handle, GetLocalArgs(comp_args.args),
1320                        &comp_tensor_rets);
1321       if (!run_status.ok()) {
1322         VLOG(2) << "Component function execution failed: " << run_status;
1323         const string function_and_msg = strings::StrCat(
1324             errors::FormatFunctionForError(data->function_name_), " ",
1325             run_status.error_message());
1326         if (opts.rendezvous != nullptr) opts.rendezvous->StartAbort(run_status);
1327         return errors::CreateWithUpdatedMessage(run_status, function_and_msg);
1328       } else {
1329         VLOG(2) << "Component function execution succeeded.";
1330         for (int i = 0; i < comp_tensor_rets.size(); ++i) {
1331           (*rets)[comp_data.ret_indices[i]] = comp_tensor_rets[i];
1332         }
1333       }
1334     } else {
1335       // Fall back to DistributedFunctionLibraryRuntime for remote execution.
1336       opts_copy.remote_execution = true;
1337       VLOG(4) << "    with " << opts_copy.DebugString();
1338 
1339       std::vector<std::unique_ptr<CleanUpItem>> cleanup_items;
1340       Notification n;
1341       Status s;
1342       std::vector<FunctionRet> comp_rets;
1343       RunInternal(opts_copy, comp_handle, comp_args.args, &comp_rets,
1344                   &cleanup_items, [&n, &s](const Status& status) {
1345                     s.Update(status);
1346                     n.Notify();
1347                   });
1348       n.WaitForNotification();
1349       return s;
1350     }
1351   }
1352   return OkStatus();
1353 }
1354 
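// Launches all component functions concurrently. A ReffedStatusCallback
// aggregates the per-component statuses and invokes `done` once every
// component has finished; the first failure cancels the remaining components
// through the shared CancellationManager.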
1355 void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync(
1356     const FunctionLibraryRuntime::Options& opts,
1357     FunctionLibraryRuntime::Handle outer_handle, std::vector<FunctionRet>* rets,
1358     std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
1359     FunctionLibraryRuntime::DoneCallback done,
1360     std::function<Status(const ComponentFunctionData& comp_data,
1361                          InternalArgs* args)>
1362         get_component_args) const {
1363   const MultiDeviceFunctionData* data;
1364   Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data);
1365   if (!prepare_status.ok()) {
1366     done(prepare_status);
1367     return;
1368   }
1369 
1370   // A locally created cancellation manager, used only when the caller does not
1371   // provide one as an argument.
1372   std::shared_ptr<CancellationManager> local_cm;
1373   CancellationManager* cm = opts.cancellation_manager;
1374   if (cm == nullptr) {
1375     local_cm = std::make_shared<CancellationManager>();
1376     cm = local_cm.get();
1377   }
1378 
1379   auto* refcounted_done = new ReffedStatusCallback(std::move(done));
1380   for (int i = 0; i < data->glue_.size(); ++i) {
1381     refcounted_done->Ref();
1382   }
1383 
1384   FunctionLibraryRuntime::Options opts_copy = opts;
1385   for (const auto& pair : data->glue_) {
1386     const string& target = pair.first;
1387     const ComponentFunctionData& comp_data = pair.second;
1388     FunctionLibraryRuntime::Handle comp_handle = pair.second.handle;
1389 
1390     opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
1391     opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
1392     opts_copy.cancellation_manager = cm;
1393 
1394     InternalArgs comp_args;
1395     Status s = get_component_args(comp_data, &comp_args);
1396     if (!s.ok()) {
1397       VLOG(2) << "Failed to get component function arguments: " << s;
1398       refcounted_done->UpdateStatus(s);
1399       refcounted_done->Unref();
1400       cm->StartCancel();
1401       continue;
1402     }
1403     std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
1404     rets->resize(data->num_outputs_);
1405 
1406     auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
1407                                   cm, local_cm, data, comp_handle,
1408                                   target](const Status& status) {
1409       if (!status.ok()) {
1410         VLOG(2) << "Component function execution on target " << target
1411                 << " from " << data->function_name_ << " with handle "
1412                 << comp_handle << " failed: " << status;
1413         const string function_and_msg = strings::StrCat(
1414             errors::FormatFunctionForError(data->function_name_), " ",
1415             status.error_message());
1416         refcounted_done->UpdateStatus(
1417             errors::CreateWithUpdatedMessage(status, function_and_msg));
1418         // Cancel the execution of other component functions.
1419         cm->StartCancel();
1420       } else {
1421         VLOG(2) << "Component function execution on target " << target
1422                 << " from " << data->function_name_ << " with handle "
1423                 << comp_handle << " succeeded.";
1424         for (int i = 0; i < comp_rets->size(); ++i) {
1425           (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
1426         }
1427       }
1428       delete comp_rets;
1429       // refcounted_done is thread-safe
1430       refcounted_done->Unref();
1431     };
1432 
1433     FunctionLibraryRuntime* flr = GetFLR(target);
1434     if (flr != nullptr) {
1435       opts_copy.remote_execution = false;
1436       // When the target device has a private thread pool, use the target
1437       // device's runner.
1438       thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
1439       opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();
1440 
1441       VLOG(1) << "Running component function on device " << target << " from "
1442               << data->function_name_ << " with handle " << comp_handle;
1443       VLOG(4) << "    with " << opts_copy.DebugString();
1444 
1445       std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
1446       flr->Run(
1447           opts_copy, comp_handle, GetLocalArgs(comp_args.args),
1448           comp_tensor_rets,
1449           TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
1450                                             std::move(component_fn_callback)));
1451     } else {
1452       opts_copy.remote_execution = true;
1453 
1454       VLOG(1) << "Running component function on device " << target << " from "
1455               << data->function_name_ << " with handle " << comp_handle;
1456       VLOG(4) << "    with " << opts_copy.DebugString();
1457 
1458       RunInternal(opts_copy, comp_handle, comp_args.args, comp_rets,
1459                   cleanup_items, std::move(component_fn_callback));
1460     }
1461   }
1462   refcounted_done->Unref();
1463 }
1464 
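// Instantiates `function_name`: multi-device functions go through
// InstantiateMultiDevice, functions with a local target FLR are instantiated
// synchronously, and remote targets go through InstantiateRemote behind a
// blocking Notification.
//
// Minimal usage sketch (illustrative only: `pflr`, "XTimesTwo", and the device
// name are placeholders, not defined in this file):
//
//   FunctionLibraryRuntime::InstantiateOptions inst_opts;
//   inst_opts.target = "/job:localhost/replica:0/task:0/device:CPU:0";
//   FunctionLibraryRuntime::Handle h;
//   TF_RETURN_IF_ERROR(
//       pflr->Instantiate("XTimesTwo", AttrSlice(), inst_opts, &h));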
1465 Status ProcessFunctionLibraryRuntime::Instantiate(
1466     const string& function_name, AttrSlice attrs,
1467     const FunctionLibraryRuntime::InstantiateOptions& options,
1468     FunctionLibraryRuntime::Handle* handle) {
1469   if (options.is_multi_device_function) {
1470     return InstantiateMultiDevice(function_name, attrs, options, handle);
1471   }
1472 
1473   *handle = kInvalidHandle;
1474   FunctionLibraryRuntime* flr = GetFLR(options.target);
1475   if (flr != nullptr) {
1476     return flr->Instantiate(function_name, attrs, options, handle);
1477   }
1478 
1479   Status status;
1480   Notification notification;
1481   InstantiateRemote(function_name, attrs, options, handle,
1482                     [&status, &notification](const Status& s) {
1483                       status = s;
1484                       notification.Notify();
1485                     });
1486   notification.WaitForNotification();
1487   return status;
1488 }
1489 
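// Sets `is_cross_process` to whether the function behind `handle`
// (multi-device or single-device) has been marked as crossing process
// boundaries.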
1490 Status ProcessFunctionLibraryRuntime::IsCrossProcess(
1491     FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const {
1492   tf_shared_lock l(mu_);
1493   const auto& mdevice_it = mdevice_data_.find(handle);
1494   if (mdevice_it != mdevice_data_.end()) {
1495     *is_cross_process = mdevice_it->second->is_cross_process_;
1496     return OkStatus();
1497   }
1498   const auto& it = function_data_.find(handle);
1499   if (it != function_data_.end()) {
1500     *is_cross_process = it->second->is_cross_process();
1501     return OkStatus();
1502   }
1503   return errors::InvalidArgument("Handle ", handle, " not found.");
1504 }
1505 
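// Instantiates a function on a remote target through the
// DistributedFunctionLibraryRuntime `parent_`. Reuses or creates the
// FunctionData entry keyed by the canonicalized function key and starts
// DistributedInit; `done` is invoked when the remote instantiation completes.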
1506 void ProcessFunctionLibraryRuntime::InstantiateRemote(
1507     const string& function_name, AttrSlice attrs,
1508     const FunctionLibraryRuntime::InstantiateOptions& options,
1509     FunctionLibraryRuntime::Handle* handle,
1510     FunctionLibraryRuntime::DoneCallback done) {
1511   if (parent_ == nullptr) {
1512     done(errors::Internal(
1513         "Currently don't support instantiating functions on device: ",
1514         options.target));
1515     return;
1516   }
1517   auto target = options.target;
1518   VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target;
1519   string function_key = Canonicalize(function_name, attrs, options);
1520   FunctionData* f;
1521   {
1522     mutex_lock l(mu_);
1523     FunctionLibraryRuntime::Handle h =
1524         gtl::FindWithDefault(table_, function_key, kInvalidHandle);
1525     if (h == kInvalidHandle || function_data_.count(h) == 0) {
1526       h = AddHandleLocked(function_key, target, kInvalidHandle);
1527     }
1528     f = function_data_[h].get();
1529     *handle = h;
1530   }
1531   f->DistributedInit(
1532       parent_, function_name,
1533       options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options,
1534       [this, function_name, target, handle, done](const Status& s) {
1535         VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
1536                 << " on: " << target << " with handle: " << *handle
1537                 << " (this: " << this << ")";
1538         done(s);
1539       });
1540 }
1541 
1542 Status ProcessFunctionLibraryRuntime::RemoveHandle(
1543     FunctionLibraryRuntime::Handle handle) {
1544   mutex_lock l(mu_);
1545   table_.erase(function_data_[handle]->function_key());
1546   function_data_.erase(handle);
1547   return OkStatus();
1548 }
1549 
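// Decrements the instantiation counter for a multi-device handle and, once it
// reaches zero, erases the bookkeeping and releases every local component
// handle. Releasing component handles that live on remote devices is not
// implemented yet.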
1550 Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
1551     FunctionLibraryRuntime::Handle handle) {
1552   std::unique_ptr<MultiDeviceFunctionData> mdata;
1553   {
1554     mutex_lock l(mu_);
1555     auto it = mdevice_data_.find(handle);
1556     --it->second->instantiation_counter_;
1557     if (it->second->instantiation_counter_ != 0) {
1558       return OkStatus();
1559     }
1560     mdata = std::move(it->second);
1561     table_.erase(mdata->function_key_);
1562     mdevice_data_.erase(it);
1563   }
1564 
1565   // If we are here we are releasing the last instantiation of `handle`.
1566   // Release all component function handles.
1567   Status overall_status;
1568   for (const auto& it : mdata->glue_) {
1569     const string& device = it.first;
1570     FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
1571     FunctionLibraryRuntime* flr = GetFLR(device);
1572     if (flr == nullptr) {
1573       // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
1574       // parent is not null.
1575       if (parent_ != nullptr) {
1576         return errors::Unimplemented(
1577             "Releasing a multi-device component handle on a remote device is "
1578             "not yet implemented.");
1579       }
1580       return errors::InvalidArgument(
1581           "Failed to find FunctionLibraryRuntime for device ", device,
1582           " when releasing multi-device function handle ", handle);
1583     }
1584     Status status = flr->ReleaseHandle(flr_handle);
1585     if (!status.ok()) {
1586       overall_status = status;
1587     }
1588   }
1589 
1590   return overall_status;
1591 }
1592 
1593 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
1594     FunctionLibraryRuntime::Handle handle) {
1595   // Return directly if all function handles have already been released.
1596   if (flr_map_ == nullptr) return OkStatus();
1597 
1598   if (IsMultiDevice(handle)) {
1599     return ReleaseMultiDeviceHandle(handle);
1600   }
1601 
1602   FunctionLibraryRuntime* flr = nullptr;
1603   string target_device;
1604   {
1605     mutex_lock l(mu_);
1606     CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
1607     target_device = function_data_[handle]->target_device();
1608   }
1609   flr = GetFLR(target_device);
1610   if (flr != nullptr) {
1611     return flr->ReleaseHandle(handle);
1612   }
1613   return errors::InvalidArgument("Handle not found: ", handle);
1614 }
1615 
1616 void ProcessFunctionLibraryRuntime::CleanupCreatedRendezvous(
1617     const Rendezvous* created_rendezvous, const int64_t step_id) const {
1618   if (created_rendezvous) {
1619     DCHECK(rendezvous_factory_);
1620     created_rendezvous->Unref();
1621     Status s = rendezvous_factory_.CleanUp(step_id);
1622     if (!s.ok()) {
1623       LOG(ERROR) << s;
1624     }
1625   }
1626 }
1627 
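// Wraps `done` so that, when the wrapped callback fires, the created
// rendezvous is cleaned up and all CleanUpItems are processed before `done`
// receives the combined status.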
1628 FunctionLibraryRuntime::DoneCallback
1629 ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
1630     std::vector<std::unique_ptr<CleanUpItem>>* items,
1631     FunctionLibraryRuntime::DoneCallback done, const int64_t step_id,
1632     const Rendezvous* created_rendezvous) const {
1633   return [this, items, done = std::move(done), step_id,
1634           created_rendezvous](const Status& status) {
1635     this->CleanupCreatedRendezvous(created_rendezvous, step_id);
1636     auto* local_status = new Status(status);
1637     CleanUp(items, [local_status, done](const Status& cleanup_status) {
1638       local_status->Update(cleanup_status);
1639       done(*local_status);
1640       delete local_status;
1641     });
1642     delete items;
1643   };
1644 }
1645 
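// Creates a per-step rendezvous via `rendezvous_factory_` when the caller did
// not supply one. On success the rendezvous is installed in `opts` and
// opts.create_rendezvous is cleared so component FLRs do not create their own.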
1646 Status ProcessFunctionLibraryRuntime::CreateRendezvous(
1647     FunctionLibraryRuntime::Options& opts,
1648     Rendezvous** created_rendezvous) const {
1649   DCHECK(opts.rendezvous == nullptr);
1650   if (!rendezvous_factory_) {
1651     return errors::FailedPrecondition(
1652         "The caller does not provide a rendezvous and "
1653         "ProcessFunctionLibraryRuntime was created without a rendezvous "
1654         "factory.");
1655   }
1656   Status s = rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
1657   if (s.ok()) {
1658     opts.rendezvous = *created_rendezvous;
1659     opts.create_rendezvous = false;
1660   }
1661   return s;
1662 }
1663 
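// Selects the arguments of a single component function from the caller's full
// argument list. A non-negative sub_index picks one ResourceHandle out of a
// packed DT_RESOURCE argument. Illustrative example (hypothetical values):
// arg_indices = {{1, -1}, {0, 2}} forwards args[1] unchanged and the third
// handle packed inside the resource tensor args[0].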
1664 Status ProcessFunctionLibraryRuntime::GetComponentArgs(
1665     const gtl::ArraySlice<Tensor> args,
1666     const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data,
1667     ProcessFunctionLibraryRuntime::InternalArgs* comp_args) {
1668   // The "index" attrs of _Arg nodes are unique when all arguments are local Tensors.
1669   for (const auto& it : comp_data.arg_indices) {
1670     if (it.index >= args.size()) {
1671       return errors::InvalidArgument("index ", it.index,
1672                                      " is out of range [0, ", args.size(), ")");
1673     }
1674     if (it.sub_index >= 0) {
1675       const Tensor& t = args[it.index];
1676       if (t.dtype() != DT_RESOURCE) {
1677         return errors::InvalidArgument("Got unexpected sub_index ",
1678                                        it.sub_index, " for argument ",
1679                                        it.index);
1680       }
1681       const auto& handles = t.flat<ResourceHandle>();
1682       if (it.sub_index >= handles.size()) {
1683         return errors::InvalidArgument("Sub_index ", it.sub_index,
1684                                        " is out of range [0, ", handles.size(),
1685                                        ") for argument ", it.index);
1686       }
1687       comp_args->args.push_back(Tensor(handles(it.sub_index)));
1688     } else {
1689       comp_args->args.push_back(args[it.index]);
1690     }
1691   }
1692   return OkStatus();
1693 }
1694 
1695 #if !defined(IS_MOBILE_PLATFORM)
1696 Status ProcessFunctionLibraryRuntime::GetComponentArgs(
1697     const FunctionArgsInterface& args,
1698     const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data,
1699     ProcessFunctionLibraryRuntime::InternalArgs* comp_args) {
1700   for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
1701     const FunctionArgIndex index = comp_data.arg_indices.at(i);
1702     Tensor tensor;
1703     if (args.GetLocalArg(index, &tensor).ok()) {
1704       comp_args->args.push_back(std::move(tensor));
1705     } else {
1706       eager::RemoteTensorHandle remote_handle;
1707       TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
1708       comp_args->remote_args.emplace_back(
1709           std::make_unique<eager::RemoteTensorHandle>(
1710               std::move(remote_handle)));
1711       comp_args->args.push_back(comp_args->remote_args.back().get());
1712     }
1713   }
1714   return OkStatus();
1715 }
1716 #endif  // IS_MOBILE_PLATFORM
1717 
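// Asynchronously runs an instantiated function with plain Tensor arguments.
// Creates a rendezvous if the caller did not provide one, converts the
// internal FunctionRet results back to Tensors, and dispatches to
// RunMultiDeviceAsync for multi-device handles or to RunInternal otherwise.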
1718 void ProcessFunctionLibraryRuntime::Run(
1719     const FunctionLibraryRuntime::Options& opts,
1720     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
1721     std::vector<Tensor>* rets,
1722     FunctionLibraryRuntime::DoneCallback done) const {
1723   FunctionLibraryRuntime::Options new_opts = opts;
1724   Rendezvous* created_rendezvous = nullptr;
1725   if (!opts.rendezvous) {
1726     Status s = CreateRendezvous(new_opts, &created_rendezvous);
1727     if (!s.ok()) {
1728       done(s);
1729       return;
1730     }
1731   }
1732 
1733   auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
1734   done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
1735                                     new_opts.step_id, created_rendezvous);
1736   std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
1737   done = [rets, function_rets, done = std::move(done)](const Status& s) {
1738     Status status = s;
1739     if (status.ok()) {
1740       status.Update(FunctionRetsToTensors(function_rets, rets));
1741     }
1742     delete function_rets;
1743     done(status);
1744   };
1745   bool multi_device = HasMultiDeviceHandle(handle);
1746   if (multi_device) {
1747     auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1748                                       InternalArgs* comp_args) -> Status {
1749       return GetComponentArgs(args, comp_data, comp_args);
1750     };
1751     return RunMultiDeviceAsync(new_opts, handle, function_rets, cleanup_items,
1752                                std::move(done), std::move(get_component_args));
1753   }
1754   std::vector<FunctionArg> local_args;
1755   for (const auto& tensor : args) {
1756     local_args.push_back(tensor);
1757   }
1758   RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
1759               std::move(done));
1760 }
1761 
1762 // This method handles the simple remote call case (not multi-device).
1763 void ProcessFunctionLibraryRuntime::RunInternal(
1764     const FunctionLibraryRuntime::Options& opts,
1765     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
1766     std::vector<FunctionRet>* rets,
1767     std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
1768     FunctionLibraryRuntime::DoneCallback done) const {
1769   FunctionLibraryRuntime* flr = nullptr;
1770   string target_device;
1771   FunctionLibraryRuntime::LocalHandle local_handle;
1772   {
1773     tf_shared_lock l(mu_);
1774     auto iter = function_data_.find(handle);
1775     if (iter == function_data_.end()) {
1776       done(errors::NotFound("Handle: ", handle, " not found."));
1777       return;
1778     }
1779     FunctionData* function_data = iter->second.get();
1780     target_device = function_data->target_device();
1781     local_handle = function_data->local_handle();
1782   }
1783 
1784   if (!opts.remote_execution) {
1785     done(
1786         errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
1787                                 "only be called for multi-device functions or "
1788                                 "for remote execution."));
1789     return;
1790   }
1791 
1792   flr = GetFLR(target_device);
1793   if (flr != nullptr) {
1794     auto rendezvous = opts.rendezvous;
1795     string source_device = opts.source_device;
1796     DeviceContext* device_context;
1797     Status s = GetDeviceContext(source_device, &device_context);
1798     if (!s.ok()) {
1799       done(s);
1800       return;
1801     }
1802     int64_t src_incarnation, target_incarnation;
1803     s = GetDeviceIncarnation(source_device, &src_incarnation);
1804     s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
1805     if (!s.ok()) {
1806       done(s);
1807       return;
1808     }
1809 
1810     std::vector<Tensor> local_args = GetLocalArgs(args);
1811 
1812     // Send the args over to the target device.
1813     s = SendTensors(source_device, target_device, "arg_", src_incarnation,
1814                     local_args, device_context, opts.args_alloc_attrs,
1815                     rendezvous);
1816     if (!s.ok()) {
1817       done(s);
1818       return;
1819     }
1820     const std::vector<AllocatorAttributes>& rets_alloc_attrs =
1821         opts.rets_alloc_attrs;
1822     std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
1823     flr->Run(opts, handle, local_args, remote_rets,
1824              [source_device, target_device, target_incarnation, rendezvous,
1825               device_context, rets_alloc_attrs, remote_rets, rets,
1826               done = std::move(done)](const Status& status) mutable {
1827                if (!status.ok()) {
1828                  delete remote_rets;
1829                  done(status);
1830                  return;
1831                }
1832                int64_t num_returns = remote_rets->size();
1833                delete remote_rets;
1834                // Now receive the return values from the target.
1835                std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
1836                ReceiveTensorsAsync(target_device, source_device, "ret_",
1837                                    target_incarnation, num_returns,
1838                                    device_context, rets_alloc_attrs, rendezvous,
1839                                    recv_tensors,
1840                                    TensorsToFunctionRetsDoneCallback(
1841                                        rets, recv_tensors, std::move(done)));
1842              });
1843     return;
1844   }
1845   if (parent_ != nullptr) {
1846     auto cleanup_item = std::make_unique<CleanUpItem>();
1847     cleanup_item->device = target_device;
1848     cleanup_item->step_id = opts.step_id;
1849     cleanup_item->local_handle = local_handle;
1850     cleanup_items->emplace_back(std::move(cleanup_item));
1851     parent_->Run(opts, local_handle, args, rets, std::move(done));
1852     return;
1853   }
1854   done(errors::Internal("Could not find device"));
1855 }
1856 
1857 void ProcessFunctionLibraryRuntime::Run(
1858     const FunctionLibraryRuntime::Options& opts,
1859     FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
1860     FunctionLibraryRuntime::DoneCallback done) const {
1861   std::vector<Tensor> args;
1862   args.reserve(frame->num_args());
1863   for (size_t i = 0; i < frame->num_args(); ++i) {
1864     const Tensor* arg;
1865     Status s = frame->GetArg(i, &arg);
1866     if (!s.ok()) {
1867       return done(s);
1868     }
1869     args.emplace_back(*arg);
1870   }
1871   std::vector<Tensor>* rets = new std::vector<Tensor>;
1872   rets->reserve(frame->num_retvals());
1873 
1874   Run(opts, handle, args, rets,
1875 
1876       [frame, rets, done = std::move(done)](const Status& status) {
1877         std::unique_ptr<std::vector<Tensor>> rets_releaser(rets);
1878 
1879         if (!status.ok()) {
1880           done(status);
1881           return;
1882         }
1883 
1884         if (rets->size() != frame->num_retvals()) {
1885           done(errors::Internal(
1886               "Number of return values from function (", rets->size(),
1887               ") did not match expected number of return values (",
1888               frame->num_retvals(), ")."));
1889           return;
1890         }
1891 
1892         for (size_t i = 0; i < frame->num_retvals(); ++i) {
1893           Status s = frame->SetRetval(i, (*rets)[i]);
1894           if (!s.ok()) {
1895             done(s);
1896             return;
1897           }
1898         }
1899         done(OkStatus());
1900       });
1901 }
1902 
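// Synchronous entry point. Multi-device functions instantiated with
// enable_sync_execution run on the calling thread via RunMultiDeviceSync;
// everything else falls back to the asynchronous Run and blocks on a
// Notification.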
1903 Status ProcessFunctionLibraryRuntime::RunSync(
1904     const FunctionLibraryRuntime::Options& orig_opts,
1905     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
1906     std::vector<Tensor>* rets) const {
1907   MultiDeviceFunctionData* multi_device_data = IsMultiDevice(handle);
1908   if (multi_device_data && multi_device_data->enable_sync_execution) {
1909     metrics::IncrementTestCounter("pflr_runsync", "sync");
1910     FunctionLibraryRuntime::Options new_opts = orig_opts;
1911     Rendezvous* created_rendezvous = nullptr;
1912     if (!new_opts.rendezvous) {
1913       TF_RETURN_IF_ERROR(CreateRendezvous(new_opts, &created_rendezvous));
1914     }
1915 
1916     std::vector<FunctionRet> function_rets;
1917     auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1918                                       InternalArgs* comp_args) {
1919       return GetComponentArgs(args, comp_data, comp_args);
1920     };
1921 
1922     Status status = RunMultiDeviceSync(new_opts, handle, &function_rets,
1923                                        std::move(get_component_args));
1924     CleanupCreatedRendezvous(created_rendezvous, new_opts.step_id);
1925     status.Update(FunctionRetsToTensors(&function_rets, rets));
1926     return status;
1927   } else {
1928     // TODO(b/207484417): Either handle or avoid/delete this fallback path.
1929     metrics::IncrementTestCounter("pflr_runsync", "async");
1930     Notification n;
1931     Status s;
1932     Run(orig_opts, handle, args, rets, [&n, &s](const Status& status) {
1933       s.Update(status);
1934       n.Notify();
1935     });
1936     n.WaitForNotification();
1937     return s;
1938   }
1939 }
1940 
1941 Status ProcessFunctionLibraryRuntime::RunSync(
1942     const FunctionLibraryRuntime::Options& opts,
1943     FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
1944   // TODO(b/207485199): Implement this as synchronous code.
1945   Notification n;
1946   Status s;
1947   Run(opts, handle, frame, [&n, &s](const Status& status) {
1948     s.Update(status);
1949     n.Notify();
1950   });
1951   n.WaitForNotification();
1952   return s;
1953 }
1954 
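// Run variant that accepts remote and packed inputs. When all inputs and
// outputs are local it forwards to the Tensor-based Run; otherwise it uses the
// asynchronous multi-device path (remote inputs are unsupported on mobile).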
1955 void ProcessFunctionLibraryRuntime::Run(
1956     const FunctionLibraryRuntime::Options& opts,
1957     FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
1958     std::vector<FunctionRet>* rets,
1959     FunctionLibraryRuntime::DoneCallback done) const {
1960   bool has_remote_outputs = false;
1961   const MultiDeviceFunctionData* data = IsMultiDevice(handle);
1962   if (data != nullptr) {
1963     has_remote_outputs = data->has_remote_outputs;
1964   }
1965   if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
1966     const std::vector<Tensor> local_inputs = args.GetLocalTensors();
1967     std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
1968     return Run(
1969         opts, handle, local_inputs, tensor_rets,
1970         TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done)));
1971   }
1972 
1973   FunctionLibraryRuntime::Options new_opts = opts;
1974   Rendezvous* created_rendezvous = nullptr;
1975   if (!opts.rendezvous) {
1976     Status s = CreateRendezvous(new_opts, &created_rendezvous);
1977     if (!s.ok()) {
1978       done(s);
1979       return;
1980     }
1981   }
1982 
1983 #if defined(IS_MOBILE_PLATFORM)
1984   done(errors::Unimplemented(
1985       "Remote inputs are not available on mobile devices."));
1986   return;
1987 #else   // !IS_MOBILE_PLATFORM
1988   auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
1989   done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
1990                                     created_rendezvous);
1991 
1992   auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1993                                     InternalArgs* comp_args) -> Status {
1994     return GetComponentArgs(args, comp_data, comp_args);
1995   };
1996   return RunMultiDeviceAsync(new_opts, handle, rets, cleanup_items,
1997                              std::move(done), std::move(get_component_args));
1998 #endif  // !IS_MOBILE_PLATFORM
1999 }
2000 
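// Releases the per-step state recorded in `items`. Only items owned by the
// distributed runtime `parent_` are supported; local items and unknown devices
// report an internal error.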
2001 void ProcessFunctionLibraryRuntime::CleanUp(
2002     std::vector<std::unique_ptr<CleanUpItem>>* items,
2003     FunctionLibraryRuntime::DoneCallback done) const {
2004   auto* refcounted_done = new ReffedStatusCallback(std::move(done));
2005   for (auto& item : *items) {
2006     refcounted_done->Ref();
2007     auto* flr = GetFLR(item->device);
2008     if (flr != nullptr) {
2009       // TODO(fishx): cleanup state for local execution.
2010       refcounted_done->UpdateStatus(
2011           errors::Internal("Cleanup items shouldn't contain local item."));
2012       refcounted_done->Unref();
2013     } else if (parent_ != nullptr) {
2014       parent_->CleanUp(item->step_id, item->local_handle,
2015                        [refcounted_done](const Status& status) {
2016                          if (!status.ok()) {
2017                            refcounted_done->UpdateStatus(status);
2018                          }
2019                          // refcounted_done is thread-safe
2020                          refcounted_done->Unref();
2021                        });
2022     } else {
2023       refcounted_done->UpdateStatus(
2024           errors::Internal("Could not find device in cleanup."));
2025       refcounted_done->Unref();
2026     }
2027   }
2028   refcounted_done->Unref();
2029 }
2030 
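// Creates a new ProcessFunctionLibraryRuntime that shares this one's
// DeviceMgr, parent, and rendezvous factory. The function library is copied,
// or left empty when `skip_flib_def` is true; registered composite devices are
// carried over.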
2031 Status ProcessFunctionLibraryRuntime::Clone(
2032     Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
2033     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
2034     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
2035     bool skip_flib_def) const {
2036   if (skip_flib_def) {
2037     *out_lib_def = std::make_unique<FunctionLibraryDefinition>(
2038         lib_def_->default_registry(), FunctionDefLibrary{});
2039   } else {
2040     *out_lib_def = std::make_unique<FunctionLibraryDefinition>(*lib_def_);
2041   }
2042   *out_pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
2043       device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
2044       out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
2045       session_metadata_, rendezvous_factory_);
2046   {
2047     tf_shared_lock l(mu_);
2048     for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
2049   }
2050   return OkStatus();
2051 }
2052 
2053 }  // namespace tensorflow
2054