/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"

#include <iterator>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";

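// Instantiates the function on a remote target through `parent` and records
// the result. Only the first caller triggers the instantiation; concurrent
// callers block until initialization finishes and then observe the same
// status.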
void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
    DistributedFunctionLibraryRuntime* parent, const string& function_name,
    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::DoneCallback done) {
  {
    mutex_lock l(mu_);
    is_cross_process_ = true;
    if (init_started_) {
      init_done_.WaitForNotification();
      done(init_result_);
      return;
    }
    init_started_ = true;
  }
  parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
                      [this, done](const Status& s) {
                        init_done_.Notify();
                        done(s);
                      });
}

ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    thread::ThreadPool* default_thread_pool,
    DistributedFunctionLibraryRuntime* parent,
    const SessionMetadata* session_metadata,
    Rendezvous::Factory rendezvous_factory)
    : parent_(parent),
      env_(env),
      config_(config ? absl::make_optional(*config) : absl::nullopt),
      device_mgr_(device_mgr),
      lib_def_(lib_def),
      default_thread_pool_(default_thread_pool),
      flr_map_(new std::unordered_map<Device*,
                                      std::unique_ptr<FunctionLibraryRuntime>>),
      next_handle_(0),
      session_metadata_(session_metadata),
      rendezvous_factory_(std::move(rendezvous_factory)),
      optimizer_options_(optimizer_options),
      graph_def_version_(graph_def_version) {
  if (device_mgr == nullptr) {
    (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
        nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
        graph_def_version, lib_def_, default_thread_pool, optimizer_options,
        session_metadata_, this);
    return;
  }
  InitializeDeviceAndFlr();
}

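// Creates one rendezvous key per tensor (key_prefix + index) and sends
// `tensors_to_send` from `source_device` to `target_device` through
// `rendezvous`.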
/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
    const string& source_device, const string& target_device,
    const string& key_prefix, int64_t src_incarnation,
    gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
    const std::vector<AllocatorAttributes>& alloc_attrs,
    RendezvousInterface* rendezvous) {
  std::vector<string> keys;
  for (int i = 0; i < tensors_to_send.size(); ++i) {
    string name = strings::StrCat(key_prefix, i);
    string key = Rendezvous::CreateKey(source_device, src_incarnation,
                                       target_device, name, FrameAndIter(0, 0));
    keys.push_back(key);
  }
  TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
      rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
  return Status::OK();
}

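// Receives `num_tensors` tensors from `rendezvous` using the same key scheme
// as SendTensors, appends them to `received_tensors`, and invokes `done` once
// all tensors have arrived.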
/* static */
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
    const string& source_device, const string& target_device,
    const string& key_prefix, int64_t src_incarnation, int64_t num_tensors,
    DeviceContext* device_context,
    const std::vector<AllocatorAttributes>& alloc_attrs,
    RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
    StatusCallback done) {
  std::vector<string> keys;
  for (int64_t i = 0; i < num_tensors; ++i) {
    string name = strings::StrCat(key_prefix, i);
    string key = Rendezvous::CreateKey(source_device, src_incarnation,
                                       target_device, name, FrameAndIter(0, 0));
    keys.push_back(key);
  }
  RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
                                 received_tensors, std::move(done));
}

Status ProcessFunctionLibraryRuntime::GetRetTypes(
    FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) {
  FunctionLibraryRuntime* flr = nullptr;
  {
    tf_shared_lock l(mu_);
    auto miter = mdevice_data_.find(h);
    if (miter != mdevice_data_.end()) {
      *ret_types = miter->second->ret_types_;
      return Status::OK();
    }
    auto fiter = function_data_.find(h);
    if (fiter != function_data_.end()) {
      flr = GetFLR(fiter->second->target_device());
    }
  }
  if (flr != nullptr) {
    return flr->GetRetTypes(h, ret_types);
  }
  return errors::InvalidArgument("Handle ", h, " not found.");
}

Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
    const string& device_name, int64* incarnation) const {
  FunctionLibraryRuntime* flr = GetFLR(device_name);
  if (flr == nullptr) {
    return errors::InvalidArgument("Device name: ", device_name, " not found.");
  }
  *incarnation = flr->device()->attributes().incarnation();
  return Status::OK();
}

Status ProcessFunctionLibraryRuntime::GetDeviceContext(
    const string& device_name, DeviceContext** device_context) const {
  *device_context = nullptr;
  FunctionLibraryRuntime* flr = GetFLR(device_name);
  if (flr == nullptr) {
    return errors::InvalidArgument("Device name: ", device_name, " not found.");
  }
  Device* device = flr->device();
  string device_type = device->parsed_name().type;
  if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
    // "TPU_SYSTEM" indicates that `device` is a CPU.
    return Status::OK();
  }

  if (device->IsRemoteCallAllowed()) {
    auto* dev_info = flr->device()->tensorflow_gpu_device_info();
    if (dev_info) {
      *device_context = dev_info->default_context;
      return Status::OK();
    }
  }

  return errors::Internal("Device type: ", device_type,
                          " is currently unsupported for remote ",
                          "function executions");
}

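// Builds the DeviceSet from all known devices (including remote devices when a
// parent distributed runtime is present) and creates a FunctionLibraryRuntime
// for every local device that does not already have one.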
void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
  DeviceMgr const* all_devices = device_mgr_;
  if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
    all_devices = parent_->remote_device_mgr();
  }

  mutex_lock l(mu_);
  device_set_ = std::make_shared<DeviceSet>();
  for (auto d : all_devices->ListDevices()) {
    device_set_->AddDevice(d);
  }
  for (Device* d : device_mgr_->ListDevices()) {
    if ((*flr_map_)[d] == nullptr) {
      (*flr_map_)[d] = NewFunctionLibraryRuntime(
          device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
          graph_def_version_, lib_def_, default_thread_pool_,
          optimizer_options_, session_metadata_, this);
    }
  }
}

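// Returns the FunctionLibraryRuntime registered for `device_name`, or nullptr
// if no such device is known to the local process.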
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
    const string& device_name) const {
  Device* device = nullptr;
  if (device_name != kDefaultFLRDevice) {
    if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
      VLOG(4) << "Could not find device: " << device_name;
      return nullptr;
    }
  }
  const auto& iter = flr_map_->find(device);
  if (iter == flr_map_->end()) {
    VLOG(1) << "Could not find device: " << device_name
            << " in the local process.";
    return nullptr;
  }
  return iter->second.get();
}

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
    const string& function_key, const string& device_name,
    FunctionLibraryRuntime::LocalHandle local_handle) {
  mutex_lock l(mu_);
  return AddHandleLocked(function_key, device_name, local_handle);
}

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
    const string& function_key, const string& device_name,
    FunctionLibraryRuntime::LocalHandle local_handle) {
  auto h = next_handle_;
  function_data_[h] =
      absl::make_unique<FunctionData>(device_name, local_handle, function_key);
  table_[function_key] = h;
  next_handle_++;
  return h;
}

FunctionLibraryRuntime::Handle
ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
    std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
  mutex_lock l(mu_);
  auto h = next_handle_;
  mdevice_data_[h] = std::move(data);
  table_[function_key] = h;
  next_handle_++;
  return h;
}

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
    const string& function_key) const {
  tf_shared_lock l(mu_);
  return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
}

bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle) const {
  return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
}

FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle,
    bool include_multi_device) const {
  tf_shared_lock l(mu_);

  auto miter = mdevice_data_.find(handle);
  if (miter != mdevice_data_.end()) {
    if (!include_multi_device) return kInvalidLocalHandle;

    const MultiDeviceFunctionData& data = *miter->second;
    if (data.glue_.size() != 1) return kInvalidLocalHandle;

    const auto& pair = *data.glue_.begin();
    const string& func_device_name = pair.first;
    const ComponentFunctionData& component_data = pair.second;
    if (func_device_name != device_name) return kInvalidLocalHandle;

    // Replace the given handle with the handle for the single component
    // function.
    handle = component_data.handle;
  }

  auto iter = function_data_.find(handle);
  if (iter == function_data_.end()) {
    return kInvalidLocalHandle;
  }
  FunctionData* function_data = iter->second.get();
  if (function_data->target_device() != device_name) {
    return kInvalidLocalHandle;
  }
  return function_data->local_handle();
}

string ProcessFunctionLibraryRuntime::GetDeviceName(
    FunctionLibraryRuntime::Handle handle) const {
  tf_shared_lock l(mu_);
  auto iter = function_data_.find(handle);
  CHECK(iter != function_data_.end());
  FunctionData* function_data = iter->second.get();
  return function_data->target_device();
}

ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
ProcessFunctionLibraryRuntime::IsMultiDevice(
    FunctionLibraryRuntime::Handle handle) const {
  tf_shared_lock l(mu_);
  const auto& it = mdevice_data_.find(handle);
  if (it != mdevice_data_.end()) {
    return it->second.get();
  }
  return nullptr;
}

namespace {
// Sets `group` to the first colocation group specified in `node`. If no
// group is specified, does not touch `group`.
void GetColocationGroup(const Node* node, string* group) {
  // We hoist the conversion from C-style string literal to string here,
  // so that we can avoid the many repeated calls to strlen().
  static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
  const AttrValue* attr_value =
      node->attrs().Find(kColocationAttrNameStringPiece);
  if (attr_value != nullptr && attr_value->has_list() &&
      attr_value->list().s_size() > 0) {
    *group = attr_value->list().s(0);
  }
}

const string* AssignedOrRequestedDeviceName(const Node& node) {
  if (node.has_assigned_device_name()) {
    return &node.assigned_device_name();
  }
  return &node.requested_device();
}

Status SetArgShape(const std::unordered_map<int, DtypeAndPartialTensorShape>&
                       input_resource_dtypes_and_shapes,
                   const std::vector<Node*>& arg_nodes) {
  for (Node* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    if (dtype == DT_RESOURCE) {
      auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
      if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
        AttrValue dtype_attr_value;
        dtype_attr_value.mutable_list()->add_type(
            dtype_and_shape_iter->second.dtype);
        n->AddAttr("_handle_dtypes", dtype_attr_value);
        TensorShapeProto shape_proto;
        dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
        AttrValue shape_attr_value;
        *shape_attr_value.mutable_list()->add_shape() = shape_proto;
        n->AddAttr("_handle_shapes", shape_attr_value);
      }
    }
  }
  return Status::OK();
}

// Returns the local tensors referred to by `args`.
std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
  std::vector<Tensor> tensors;
  for (const auto& arg : args) {
    if (arg.index() == 0) {
      tensors.push_back(absl::get<Tensor>(arg));
    }
  }
  return tensors;
}

// Updates the done callback to push the Tensors in `tensors` into `rets`.
FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
    std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
    FunctionLibraryRuntime::DoneCallback done) {
  return [rets, tensors, done = std::move(done)](const Status& s) {
    if (s.ok()) {
      for (const auto& t : *tensors) {
        rets->push_back(t);
      }
    }
    delete tensors;
    done(s);
  };
}

}  // anonymous namespace

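// Pins each _Arg node to the corresponding entry in `input_devices`, and each
// _Retval node either to the corresponding entry in `output_devices` or, when
// `output_devices` is empty, to a device derived from the node producing the
// return value.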
Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
    const std::vector<string>& input_devices,
    const std::vector<string>& output_devices, const DeviceSet& device_set,
    const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
    const FunctionLibraryDefinition* lib_def, Device* default_device) {
  // If output_devices are not specified, we want to set the output device
  // based on the device of the output-producing node. The output-producing
  // node can be an arg node because functions can simply return their
  // arguments. To make sure that the output-producing nodes have assigned
  // devices, we assign them to arguments first.
  for (Node* node : arg_nodes) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
    int64_t index = attr_value->i();
    node->set_assigned_device_name(input_devices[index]);
  }

  for (Node* node : ret_nodes) {
    if (output_devices.empty()) {
      DataType dtype;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));

      VLOG(3) << "Trying to determine device for node " << node->name()
              << "[T=" << DataTypeString(dtype) << "]";

      // If output_devices are empty, the node producing the retval must have
      // an explicitly assigned device or a colocation constraint to a node
      // with an explicitly assigned device.
      for (const auto& it : node->in_edges()) {
        if (it->IsControlEdge()) continue;

        Node* src_node = it->src();
        const string* src_device = AssignedOrRequestedDeviceName(*src_node);
        string colocation_group = "";
        GetColocationGroup(src_node, &colocation_group);
        VLOG(3) << "Considering src: " << src_node->name()
                << " src_device: " << *src_device
                << " colo group: " << colocation_group;
        while (src_device->empty() && colocation_group.empty() &&
               src_node->IsIdentity()) {
          // Only follow the real data input of Identity, not control edges.
          Node* input_node;
          TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
          src_node = input_node;

          src_device = AssignedOrRequestedDeviceName(*src_node);
          GetColocationGroup(src_node, &colocation_group);
          VLOG(3) << "Considering src: " << src_node->name()
                  << " src_device: " << *src_device
                  << " colo group: " << colocation_group;
        }

        // If the resource is produced by a function call node, we can't trust
        // the source node's device assignment, because multi-device functions
        // can return resources placed on multiple devices. In such cases we
        // leave the retval device assignment empty and rely on the placer to
        // infer the correct assignment based on the actual output device.
        const bool can_use_src_node_device =
            !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def, *src_node));

        if (!colocation_group.empty()) {
          AttrValue::ListValue colo_attr;
          colo_attr.add_s(colocation_group);
          std::vector<string> colo_slice = {colocation_group};
          node->AddAttr(kColocationAttrName, colo_slice);
        } else if (!src_device->empty() && can_use_src_node_device) {
          // src_device can be a partially specified device. Find the
          // matching device in the device_set.
          DeviceNameUtils::ParsedName parsed;
          if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
            return errors::InvalidArgument(
                "Failed to parse explicit device specification ", *src_device);
          }
          std::vector<Device*> matching_devices;
          device_set.FindMatchingDevices(parsed, &matching_devices);
          if (matching_devices.empty()) {
            if (default_device != nullptr) {
              matching_devices.push_back(default_device);
            } else {
              return errors::InvalidArgument(
                  "Unable to find any devices for spec ", *src_device);
            }
          } else if (matching_devices.size() != 1) {
            bool on_same_task = true;
            for (int i = 1; i < matching_devices.size(); ++i) {
              if (!DeviceNameUtils::IsSameAddressSpace(
                      matching_devices.at(0)->parsed_name(),
                      matching_devices.at(i)->parsed_name())) {
                on_same_task = false;
                break;
              }
            }
            // If the src node of an output is assigned to an address space
            // (e.g. py_func), rely on the placer to assign a device to the
            // output.
            if (on_same_task) {
              continue;
            }
            // Compare with default_device if it has a narrower scope matching
            // the requested device.
            int colocated_on_default_device = 0;
            for (int i = 0; i < matching_devices.size(); ++i) {
              if (DeviceNameUtils::IsSameAddressSpace(
                      default_device->parsed_name(),
                      matching_devices.at(i)->parsed_name())) {
                colocated_on_default_device++;
              }
            }
            // Continue if exactly one matching device is colocated with the
            // default device; otherwise raise an error below.
            if (colocated_on_default_device == 1) {
              continue;
            }

            // Convert a vector of devices to a string.
            // Using absl::StrJoin did not work in Android builds.
            string devices = "[";
            for (Device* device : matching_devices) {
              devices.append(device->name());
              devices.append(", ");
            }
            if (devices.size() > 2) {
              devices.resize(devices.size() - 2);
            }
            devices.append("]");

            return errors::InvalidArgument(
                *src_device,
                "When FunctionLibraryRuntime::Options.output_devices are "
                "not specified for a multi-device function, the device "
                "specification on the output node must match exactly one "
                "device. Matched devices are ",
                devices);
          }
          VLOG(3) << "Setting output device to " << matching_devices[0]->name()
                  << " for node " << SummarizeNode(*node);
          node->set_assigned_device_name(matching_devices[0]->name());
        } else if (!src_device->empty() && !can_use_src_node_device) {
          VLOG(3) << "Did not set device for a resource output node "
                  << SummarizeNode(*node);
        }
      }
    } else {
      const AttrValue* attr_value;
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int64_t index = attr_value->i();
      // output_devices size is checked in InstantiateMultiDevice.
      DCHECK_GT(output_devices.size(), index);
      VLOG(3) << "Setting output device to " << output_devices[index]
              << " for return at index " << index;
      node->set_assigned_device_name(output_devices[index]);
    }
  }
  return Status::OK();
}

namespace {

Status ValidateNoListArguments(
    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
    const string& function_name) {
  for (const OpDef::ArgDef& arg : args) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      return errors::InvalidArgument(
          "Function ", function_name, " has an ", arg_type, " named \"",
          arg.name(),
          "\" that is a list of tensors."
          " Multi-device functions support only single-tensor inputs"
          " and outputs");
    }
  }
  return Status::OK();
}

Status ValidateMultiDeviceOptions(
    const FunctionDef& fdef,
    const FunctionLibraryRuntime::InstantiateOptions& options) {
  const OpDef& signature = fdef.signature();
  // Multi-device functions currently do not support list inputs or outputs.
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
                                             signature.name()));
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
                                             signature.name()));
  if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
      fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
    return errors::Unimplemented(
        "Function '", signature.name(), "' has `",
        FunctionLibraryDefinition::kIntsOnDeviceAttr,
        "` attribute set. This attribute is not currently supported by "
        "multi-device functions.");
  }
  if (options.input_devices.size() != signature.input_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.input_devices must have the same length "
        "as the number of arguments: input_devices length = ",
        options.input_devices.size(),
        " number of arguments = ", signature.input_arg_size());
  }
  if (!options.output_devices.empty() &&
      options.output_devices.size() != signature.output_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.output_devices must either be empty or have the "
        "same length as the number of arguments: output_devices length = ",
        options.output_devices.size(),
        " number of arguments = ", signature.output_arg_size());
  }
  return Status::OK();
}

}  // anonymous namespace

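// Instantiates the body of `fdef` as a Graph and collects its arg and ret
// nodes, the ret node names and types, and the control-ret node names.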
Status GetGraphAndArgRets(
    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
    std::vector<string>* control_ret_node_names) {
  std::unique_ptr<FunctionBody> fbody;
  // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
  if (!fbody) {
    LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
    return errors::Internal("Failed to construct FunctionBody for ",
                            function_name);
  }
  *graph = std::unique_ptr<Graph>(fbody->graph);
  arg_nodes->reserve(fbody->arg_nodes.size());
  std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(),
            std::back_inserter(*arg_nodes));
  ret_nodes->reserve(fbody->ret_nodes.size());
  std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(),
            std::back_inserter(*ret_nodes));
  fbody->graph = nullptr;
  ret_node_names->reserve(fbody->ret_nodes.size());
  for (const Node* node : fbody->ret_nodes) {
    ret_node_names->push_back(node->name());
  }
  for (const auto& ret_type : fbody->ret_types) {
    ret_types->push_back(ret_type);
  }
  control_ret_node_names->reserve(fbody->control_ret_nodes.size());
  for (const Node* node : fbody->control_ret_nodes) {
    control_ret_node_names->push_back(node->name());
  }
  return Status::OK();
}

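// Instantiates a multi-device function: reuses a cached handle if the
// canonicalized function key has been seen before; otherwise converts the
// FunctionDef to a graph, pins args/rets, runs placement and the registered
// optimization passes, partitions the graph per device, and instantiates each
// partition as a component function (locally or remotely) before registering
// the new handle.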
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::Handle* handle) {
  // Check if this function has already been instantiated.
  const string& function_key = Canonicalize(function_name, attrs, options);

  {
    mutex_lock l(mu_);
    const auto& it = table_.find(function_key);
    if (it != table_.end()) {
      *handle = it->second;
      ++mdevice_data_[*handle]->instantiation_counter_;
      return Status::OK();
    }
  }

  VLOG(1) << "Instantiating MultiDevice function \"" << function_name
          << "\" on default device \"" << options.target << "\"";
  if (VLOG_IS_ON(3)) {
    int index = 0;
    VLOG(3) << "Requested input devices:";
    for (const string& device : options.input_devices) {
      VLOG(3) << "    [input " << index++ << "] " << device;
    }
    index = 0;
    VLOG(3) << "Requested output devices:";
    for (const string& device : options.output_devices) {
      VLOG(3) << "    [output " << index++ << "] " << device;
    }
  }

  const FunctionLibraryDefinition* lib_def =
      options.lib_def == nullptr ? lib_def_ : options.lib_def;

  const FunctionDef* fdef = lib_def->Find(function_name);
  if (fdef == nullptr) {
    return errors::InvalidArgument("Failed to find function \"", function_name,
                                   "\" in function library: ", lib_def);
  }

  TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));

  std::unique_ptr<Graph> graph;
  std::vector<Node*> arg_nodes, ret_nodes;
  std::vector<string> ret_node_names;
  DataTypeVector ret_types;
  std::vector<string> control_ret_node_names;

  TF_RETURN_IF_ERROR(GetGraphAndArgRets(
      function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
      &ret_node_names, &ret_types, &control_ret_node_names));

  GraphDef graph_def;
  graph->ToGraphDef(&graph_def);
  FunctionLibraryDefinition reachable_lib_def =
      lib_def->ReachableDefinitions(graph_def);
  *graph_def.mutable_library() = reachable_lib_def.ToProto();
  if (options.graph_collector != nullptr) {
    options.graph_collector->CollectRawGraph(graph_def);
  }

  Device* default_device = nullptr;
  if (options.default_device_to_target && !options.target.empty()) {
    // Make the `target` device the default device if nothing else is hard
    // coded. This allows the same function definition to be specialized to
    // different devices depending on the `PartitionedCallOp` device.
    FunctionLibraryRuntime* flr = GetFLR(options.target);
    if (flr == nullptr) {
      return errors::InvalidArgument(
          "Cannot instantiate multi-device function with target device ",
          options.target);
    }
    default_device = flr->device();
  }
  const std::shared_ptr<DeviceSet> dev_set = device_set();

  TF_RETURN_IF_ERROR(
      SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
  TF_RETURN_IF_ERROR(PinArgsAndRets(
      options.input_devices, options.output_devices, *dev_set, arg_nodes,
      ret_nodes, lib_def_,
      options.config_proto.allow_soft_placement() ? default_device : nullptr));

  auto data = absl::make_unique<MultiDeviceFunctionData>(
      function_name, function_key, ret_node_names.size(),
      std::move(reachable_lib_def), std::move(ret_types));

  // The runtime shouldn't depend on duplication between the function library
  // owned by the graph and the one owned by the runtime. To ensure this, for
  // now we keep the graph's function library empty and route lookups on it to
  // the runtime-owned library, which is set as its default registry.
  graph->mutable_flib_def()->set_default_registry(&data->lib_def_);
  graph->mutable_flib_def()->Clear();

  // Do not run function/graph optimization passes for component functions,
  // since those passes have already processed the main function.
  const bool should_run_optimization_passes = !options.is_component_function;
  if (!should_run_optimization_passes) {
    VLOG(1) << "Skipping function/graph optimization passes when instantiating "
               "component function "
            << function_name;
  }

  // Mapping from a function body node name to the control output name.
  std::unordered_map<string, string> node_name_to_control_ret;

  bool control_rets_updated = false;
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
        *dev_set, options.config_proto, &graph, &data->lib_def_,
        &control_ret_node_names, &control_rets_updated));
  }

  if (control_rets_updated) {
    // Function graph passes may have resulted in different nodes/node names
    // for control rets.
    for (const auto& control_ret : control_ret_node_names) {
      node_name_to_control_ret.emplace(control_ret, control_ret);
    }
  } else {
    for (const auto& control_ret : fdef->control_ret()) {
      node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
    }
  }

  GraphOptimizationPassOptions optimization_options;
  // TODO(iga): Thread other relevant options from SessionOptions.
  SessionOptions session_options;
  session_options.env = env_;
  session_options.config = options.config_proto;
  optimization_options.session_options = &session_options;
  optimization_options.graph = &graph;
  optimization_options.flib_def = &data->lib_def_;
  optimization_options.device_set = dev_set.get();
  optimization_options.is_function_graph = true;
  std::vector<CompositeDevice*> composite_devices;
  {
    tf_shared_lock l(mu_);
    for (auto* d : composite_devices_) composite_devices.push_back(d);
  }
  optimization_options.composite_devices = &composite_devices;
  optimization_options.default_function_device = default_device;
  optimization_options.function_def = fdef;

  DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
  }

  // TODO(b/124993244): Smartly merge options in nested defuns, and raise
  // exceptions/warnings in cases where nested function call options are
  // ignored.
  DumpGraph("Before calling Placer", graph.get());
  Placer placer(graph.get(), function_name, optimization_options.flib_def,
                dev_set.get(), default_device,
                options.config_proto.allow_soft_placement(),
                options.config_proto.log_device_placement());
  TF_RETURN_IF_ERROR(placer.Run());

  DumpGraph("Before running POST_PLACEMENT passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
  }

  Device* cpu_device;
  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));

  if (options.optimize_graph_fn) {
    DumpGraph("Before running graph optimization fn", graph.get());
    Status status = options.optimize_graph_fn(
        std::move(ret_node_names), std::move(control_ret_node_names),
        &data->lib_def_, *dev_set, cpu_device, &graph);
    if (!status.ok()) {
      LOG(WARNING) << "Ignoring multi-device function optimization failure: "
                   << status.ToString();
    }
    DumpGraph("After optimization", graph.get());
  }

  DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
  }

  // Expand the nodes assigned to a CompositeDevice before graph partitioning
  // to avoid generating a subgraph on a virtual device for execution. This
  // transformation should happen as late as possible so that as many graph
  // optimization passes as possible (e.g. PRE_PLACEMENT, PLACER,
  // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) run on the smaller graph.
  TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
      options.composite_devices, graph.get()));

  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectOptimizedGraph(def);
  }

  VLOG(4) << "Main function graph to be partitioned:";
  VLOG(4) << DebugString(graph->ToGraphDefDebug());

  std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(*dev_set, std::move(graph), &subgraphs));

  for (const auto& pair : subgraphs) {
    DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
                              pair.first, ")"),
              pair.second.get());
  }
  optimization_options.graph = nullptr;
  optimization_options.device_set = nullptr;
  optimization_options.partition_graphs = &subgraphs;
  // Normally POST_PARTITIONING passes are run by distributed workers.
  // Distributed workers are currently not supported in this code path, so we
  // run the passes here.
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
  }
  for (const auto& pair : subgraphs) {
    const auto* optimized_subgraph = pair.second.get();
    DumpGraph(
        strings::StrCat("After all optimization passes (", pair.first, ")"),
        optimized_subgraph);
    if (VLOG_IS_ON(1)) {
      DumpGraphDefToFile(
          strings::StrCat("pflr_after_all_optimization_passes_",
                          reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
                          pair.first),
          optimized_subgraph->ToGraphDefDebug());
    }
  }

  if (options.graph_collector != nullptr) {
    for (const auto& pair : subgraphs) {
      GraphDef def;
      pair.second->ToGraphDef(&def);
      *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
      options.graph_collector->CollectPartitionedGraph(def);
    }
  }

  // We must preserve control returns in each of the function components,
  // otherwise after function inlining we might prune side-effectful nodes.
  const auto control_ret =
      [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
    const auto it = node_name_to_control_ret.find(n->name());
    return it != node_name_to_control_ret.end()
               ? absl::make_optional<string>(it->second)
               : absl::nullopt;
  };

  int i = 0;
  // Generate a random function name so that one function does not reuse the
  // partition functions instantiated by another function.
  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
  FunctionNameGenerator name_generator(
      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
  auto subgraph_size = subgraphs.size();
  gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
  BlockingCounter counter(static_cast<int>(subgraph_size));
  auto runner = [this, subgraph_size](std::function<void()> fn) {
    // NOTE: Only use the thread pool to instantiate sub-functions when there
    // are more than 8 of them. We want to avoid the cost of switching threads
    // when there are only a few sub-functions.
    if (default_thread_pool_ != nullptr && subgraph_size > 8) {
      default_thread_pool_->Schedule(fn);
    } else {
      fn();
    }
  };
  for (const auto& pair : subgraphs) {
    Status* status = &instantiate_status[i];
    string unique_name = name_generator.GetName();
    ComponentFunctionData* comp_data = &data->glue_[pair.first];
    runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
            &control_ret, &options, status, &counter, &data] {
      const string& target = pair.first;

      const string& device_type =
          dev_set->FindDeviceByName(target)->device_type();
      Graph* subgraph = pair.second.get();

      status->Update(UpdateArgAndRetvalMetadata(
          subgraph, device_type, &comp_data->arg_indices,
          &comp_data->ret_indices, &comp_data->arg_alloc_attrs,
          &comp_data->ret_alloc_attrs));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      FunctionDef shard;
      status->Update(
          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      status->Update(data_lib_def->AddFunctionDef(shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      FunctionLibraryRuntime::InstantiateOptions opts;
      opts.executor_type = options.executor_type;
      opts.target = target;
      opts.lib_def = data_lib_def;
      opts.create_kernels_eagerly = options.create_kernels_eagerly;
      opts.state_handle = options.state_handle;
      auto attrs = AttrSlice(&shard.attr());
      VLOG(1) << "Start instantiating component function " << unique_name
              << " on device " << target;
      VLOG(4) << DebugString(shard);

      auto* component_handle = new FunctionLibraryRuntime::Handle;
      auto done = [this, status, unique_name, comp_data, component_handle,
                   &data, &counter](const Status& s) {
        status->Update(s);

        VLOG(1) << "Finished instantiating component function " << unique_name
                << " with handle " << *component_handle << " status: " << s;
        if (status->ok()) {
          {
            mutex_lock l(mu_);
            if (function_data_[*component_handle]->is_cross_process()) {
              data->is_cross_process_ = true;
            }
          }
          comp_data->handle = *component_handle;
        }
        delete component_handle;
        counter.DecrementCount();
      };

      FunctionLibraryRuntime* flr = GetFLR(opts.target);
      if (flr != nullptr) {
        // Initialize local function synchronously.
        Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
        done(s);
      } else {
        opts.ret_indices = comp_data->ret_indices;
        // Initialize remote function asynchronously.
        InstantiateRemote(unique_name, attrs, opts, component_handle, done);
      }
    });
    i += 1;
  }
  counter.Wait();
  StatusGroup group;
  for (auto& status : instantiate_status) {
    group.Update(status);
  }
  TF_RETURN_IF_ERROR(group.as_summary_status());

  *handle = AddMultiDeviceHandle(std::move(data), function_key);
  VLOG(2) << "Instantiated MultiDevice function \"" << function_name
          << "\" with handle " << *handle;
  return Status::OK();
}

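// For a multi-device function, fills `output_devices` with the device on
// which each function output is produced, based on each component's return
// indices and allocator attributes.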
Status ProcessFunctionLibraryRuntime::GetOutputDevices(
    FunctionLibraryRuntime::Handle handle,
    std::vector<Device*>* output_devices) const {
  MultiDeviceFunctionData* data = IsMultiDevice(handle);
  if (data == nullptr) {
    return errors::InvalidArgument(
        "Failed to find multi-device function handle ", handle);
  }

  for (const auto& pair : data->glue_) {
    const ComponentFunctionData& comp_data = pair.second;
    DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
    if (comp_data.ret_indices.empty()) {
      continue;
    }

    const string& target = pair.first;
    FunctionLibraryRuntime* target_flr = GetFLR(target);
    Device* target_device = nullptr;
    Device* host = nullptr;
    if (target_flr == nullptr) {
      if (!data->has_remote_outputs) {
        data->has_remote_outputs = true;
      }
      target_device = device_set()->FindDeviceByName(target);
      string remote_host;
      TF_RETURN_IF_ERROR(
          DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
      host = device_set()->FindDeviceByName(remote_host);
    } else {
      target_device = target_flr->device();
    }
    output_devices->resize(data->num_outputs_);
    for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
      int ret_index = comp_data.ret_indices[j];
      if (data->ret_types_[ret_index] == DT_RESOURCE) {
        (*output_devices)[ret_index] = target_device;
      } else {
        (*output_devices)[ret_index] =
            comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
      }
    }
  }

  return Status::OK();
}

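// Runs every component function of the multi-device function identified by
// `handle`, each on its target device, and scatters the component results into
// `rets` at the recorded return indices. A failure in any component cancels
// the remaining ones.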
void ProcessFunctionLibraryRuntime::RunMultiDevice(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
    FunctionLibraryRuntime::DoneCallback done,
    std::function<Status(const ComponentFunctionData& comp_data,
                         InternalArgs* args)>
        get_component_args) const {
  if (opts.create_rendezvous) {
    // FLR->Run() is the default entry point. It checks for cancellation,
    // creates rendezvous, etc.
    // Letting create_rendezvous through will do the wrong thing - each
    // component function will get a separate rendezvous created by its FLR.
    done(
        errors::Internal("Cannot call ProcessFunctionLibraryRuntime::Run with "
                         "create_rendezvous=true. Please run the function "
                         "using FunctionLibraryRuntime::Run"));
    return;
  }

  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
  if (data == nullptr) {
    done(errors::NotFound("Multi-device function handle ", handle,
                          " not found. Was the function instantiated?"));
    return;
  }

  VLOG(1) << "Running multi-device function " << data->function_name_;
  VLOG(4) << "    with " << opts.DebugString();

  if (data->glue_.empty()) {
    // Trivial case where the function body is empty.
    done(Status::OK());
    return;
  }

  // Check whether we have the right rendezvous.
  if (opts.rendezvous && data->is_cross_process_ &&
      !opts.rendezvous->is_cross_process()) {
    done(errors::InvalidArgument(
        "Running a cross process function ", data->function_name_,
        " without an appropriate cross process Rendezvous."));
    return;
  }

  // A locally created cancellation manager, used only when the caller does not
  // provide one as an argument.
  std::shared_ptr<CancellationManager> local_cm;
  CancellationManager* cm = opts.cancellation_manager;
  if (cm == nullptr) {
    local_cm = std::make_shared<CancellationManager>();
    cm = local_cm.get();
  }

  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (int i = 0; i < data->glue_.size(); ++i) {
    refcounted_done->Ref();
  }

  FunctionLibraryRuntime::Options opts_copy = opts;
  for (const auto& pair : data->glue_) {
    const string& target = pair.first;
    const ComponentFunctionData& comp_data = pair.second;
    FunctionLibraryRuntime::Handle handle = pair.second.handle;

    opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
    opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
    opts_copy.cancellation_manager = cm;

    InternalArgs comp_args;
    Status s = get_component_args(comp_data, &comp_args);
    if (!s.ok()) {
      VLOG(2) << "Failed to get component function arguments: " << s;
      refcounted_done->UpdateStatus(s);
      refcounted_done->Unref();
      cm->StartCancel();
      continue;
    }
    std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
    rets->resize(data->num_outputs_);

    auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
                                  cm, local_cm, data, handle,
                                  target](const Status& status) {
      if (!status.ok()) {
        VLOG(2) << "Component function execution on target " << target
                << " from " << data->function_name_ << " with handle " << handle
                << " failed: " << status;
        const string function_and_msg = strings::StrCat(
            errors::FormatFunctionForError(data->function_name_), " ",
            status.error_message());
        refcounted_done->UpdateStatus(Status(status.code(), function_and_msg));
        // Cancel the execution of other component functions.
        cm->StartCancel();
      } else {
        VLOG(2) << "Component function execution on target " << target
                << " from " << data->function_name_ << " with handle " << handle
                << " succeeded.";
        for (int i = 0; i < comp_rets->size(); ++i) {
          (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
        }
      }
      delete comp_rets;
      // refcounted_done is thread-safe.
      refcounted_done->Unref();
    };

    FunctionLibraryRuntime* flr = GetFLR(target);
    if (flr != nullptr) {
      opts_copy.remote_execution = false;
      // When the target device has a private thread pool, use the target
      // device runner.
      thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
      opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();

      VLOG(1) << "Running component function on device " << target << " from "
              << data->function_name_ << " with handle " << handle;
      VLOG(4) << "    with " << opts_copy.DebugString();

      std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
      flr->Run(
          opts_copy, handle, GetLocalArgs(comp_args.args), comp_tensor_rets,
          TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
                                            std::move(component_fn_callback)));
    } else {
      opts_copy.remote_execution = true;

      VLOG(1) << "Running component function on device " << target << " from "
              << data->function_name_ << " with handle " << handle;
      VLOG(4) << "    with " << opts_copy.DebugString();

      RunInternal(opts_copy, handle, comp_args.args, comp_rets, cleanup_items,
                  std::move(component_fn_callback));
    }
  }
  refcounted_done->Unref();
}

1217 Status ProcessFunctionLibraryRuntime::Instantiate(
1218     const string& function_name, AttrSlice attrs,
1219     const FunctionLibraryRuntime::InstantiateOptions& options,
1220     FunctionLibraryRuntime::Handle* handle) {
1221   if (options.is_multi_device_function) {
1222     return InstantiateMultiDevice(function_name, attrs, options, handle);
1223   }
1224 
1225   *handle = kInvalidHandle;
1226   FunctionLibraryRuntime* flr = GetFLR(options.target);
1227   if (flr != nullptr) {
1228     return flr->Instantiate(function_name, attrs, options, handle);
1229   }
1230 
1231   Status status;
1232   Notification notification;
1233   InstantiateRemote(function_name, attrs, options, handle,
1234                     [&status, &notification](const Status& s) {
1235                       status = s;
1236                       notification.Notify();
1237                     });
1238   notification.WaitForNotification();
1239   return status;
1240 }
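
// Example (informal sketch; `pflr`, `attrs`, and the device name are
// placeholders, not part of this file):
//   FunctionLibraryRuntime::InstantiateOptions i_opts;
//   i_opts.target = "/job:localhost/replica:0/task:0/device:CPU:0";
//   FunctionLibraryRuntime::Handle h;
//   TF_RETURN_IF_ERROR(pflr->Instantiate("MyFunc", attrs, i_opts, &h));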
1241 
1242 Status ProcessFunctionLibraryRuntime::IsCrossProcess(
1243     FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const {
1244   tf_shared_lock l(mu_);
1245   const auto& mdevice_it = mdevice_data_.find(handle);
1246   if (mdevice_it != mdevice_data_.end()) {
1247     *is_cross_process = mdevice_it->second->is_cross_process_;
1248     return Status::OK();
1249   }
1250   const auto& it = function_data_.find(handle);
1251   if (it != function_data_.end()) {
1252     *is_cross_process = it->second->is_cross_process();
1253     return Status::OK();
1254   }
1255   return errors::InvalidArgument("Handle ", handle, " not found.");
1256 }
1257 
1258 void ProcessFunctionLibraryRuntime::InstantiateRemote(
1259     const string& function_name, AttrSlice attrs,
1260     const FunctionLibraryRuntime::InstantiateOptions& options,
1261     FunctionLibraryRuntime::Handle* handle,
1262     FunctionLibraryRuntime::DoneCallback done) {
1263   if (parent_ == nullptr) {
1264     done(errors::Internal(
1265         "Currently don't support instantiating functions on device: ",
1266         options.target));
1267     return;
1268   }
1269   auto target = options.target;
1270   VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target;
1271   string function_key = Canonicalize(function_name, attrs, options);
1272   FunctionData* f;
1273   {
1274     mutex_lock l(mu_);
1275     FunctionLibraryRuntime::Handle h =
1276         gtl::FindWithDefault(table_, function_key, kInvalidHandle);
1277     if (h == kInvalidHandle || function_data_.count(h) == 0) {
1278       h = AddHandleLocked(function_key, target, kInvalidHandle);
1279     }
1280     f = function_data_[h].get();
1281     *handle = h;
1282   }
1283   f->DistributedInit(
1284       parent_, function_name,
1285       options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options,
1286       [this, function_name, target, handle, done](const Status& s) {
1287         VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
1288                 << " on: " << target << " with handle: " << *handle
1289                 << " (this: " << this << ")";
1290         done(s);
1291       });
1292 }
1293 
1294 Status ProcessFunctionLibraryRuntime::RemoveHandle(
1295     FunctionLibraryRuntime::Handle handle) {
1296   mutex_lock l(mu_);
1297   table_.erase(function_data_[handle]->function_key());
1298   function_data_.erase(handle);
1299   return Status::OK();
1300 }
1301 
1302 Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
1303     FunctionLibraryRuntime::Handle handle) {
1304   std::unique_ptr<MultiDeviceFunctionData> mdata;
1305   {
1306     mutex_lock l(mu_);
1307     auto it = mdevice_data_.find(handle);
1308     --it->second->instantiation_counter_;
1309     if (it->second->instantiation_counter_ != 0) {
1310       return Status::OK();
1311     }
1312     mdata = std::move(it->second);
1313     table_.erase(mdata->function_key_);
1314     mdevice_data_.erase(it);
1315   }
1316 
1317   // If we are here, we are releasing the last instantiation of `handle`.
1318   // Release all component function handles.
1319   Status overall_status;
1320   for (const auto& it : mdata->glue_) {
1321     const string& device = it.first;
1322     FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
1323     FunctionLibraryRuntime* flr = GetFLR(device);
1324     if (flr == nullptr) {
1325       // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
1326       // parent is not null.
1327       if (parent_ != nullptr) {
1328         return errors::Unimplemented(
1329             "Releasing a multi-device component handle on a remote device is "
1330             "not yet implemented.");
1331       }
1332       return errors::InvalidArgument(
1333           "Failed to find FunctionLibraryRuntime for device ", device,
1334           " when releasing multi-device function handle ", handle);
1335     }
1336     Status status = flr->ReleaseHandle(flr_handle);
1337     if (!status.ok()) {
1338       overall_status = status;
1339     }
1340   }
1341 
1342   return overall_status;
1343 }
1344 
1345 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
1346     FunctionLibraryRuntime::Handle handle) {
1347   // Return directly if all function handles have already been released.
1348   if (flr_map_ == nullptr) return Status::OK();
1349 
1350   if (IsMultiDevice(handle)) {
1351     return ReleaseMultiDeviceHandle(handle);
1352   }
1353 
1354   FunctionLibraryRuntime* flr = nullptr;
1355   string target_device;
1356   {
1357     mutex_lock l(mu_);
1358     CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
1359     target_device = function_data_[handle]->target_device();
1360   }
1361   flr = GetFLR(target_device);
1362   if (flr != nullptr) {
1363     return flr->ReleaseHandle(handle);
1364   }
1365   return errors::InvalidArgument("Handle not found: ", handle);
1366 }
1367 
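// Wraps `done` so that per-step state is torn down after the function
// finishes: the rendezvous created for this step (if any) is unreffed and
// cleaned up through the rendezvous factory, then all recorded CleanUpItems
// are released; the original callback finally receives the run status merged
// with any cleanup error.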
1368 FunctionLibraryRuntime::DoneCallback
1369 ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
1370     std::vector<std::unique_ptr<CleanUpItem>>* items,
1371     FunctionLibraryRuntime::DoneCallback done, const int64_t step_id,
1372     const Rendezvous* created_rendezvous) const {
1373   return [this, items, done = std::move(done), step_id,
1374           created_rendezvous](const Status& status) {
1375     if (created_rendezvous) {
1376       DCHECK(rendezvous_factory_);
1377       created_rendezvous->Unref();
1378       Status s = rendezvous_factory_.CleanUp(step_id);
1379       if (!s.ok()) {
1380         LOG(ERROR) << s;
1381       }
1382     }
1383     auto* local_status = new Status(status);
1384     CleanUp(items, [local_status, done](const Status& cleanup_status) {
1385       local_status->Update(cleanup_status);
1386       done(*local_status);
1387       delete local_status;
1388     });
1389     delete items;
1390   };
1391 }
1392 
1393 Status ProcessFunctionLibraryRuntime::CreateRendezvous(
1394     const FunctionLibraryRuntime::Options& opts,
1395     Rendezvous** created_rendezvous) const {
1396   if (rendezvous_factory_) {
1397     return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
1398   } else {
1399     return errors::FailedPrecondition(
1400         "The caller does not provide a rendezvous and "
1401         "ProcessFunctionLibraryRuntime was created without a rendezvous "
1402         "factory.");
1403   }
1404 }
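
// A minimal sketch of what a rendezvous-factory create function typically
// looks like (illustrative only; the exact Rendezvous::Factory signature
// depends on the TensorFlow version):
//   auto create_fn = [](int64_t step_id, const DeviceMgr* device_mgr,
//                       Rendezvous** r) {
//     *r = new IntraProcessRendezvous(device_mgr);
//     return Status::OK();
//   };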
1405 
1406 void ProcessFunctionLibraryRuntime::Run(
1407     const FunctionLibraryRuntime::Options& opts,
1408     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
1409     std::vector<Tensor>* rets,
1410     FunctionLibraryRuntime::DoneCallback done) const {
1411   FunctionLibraryRuntime::Options new_opts = opts;
1412   Rendezvous* created_rendezvous = nullptr;
1413   if (!opts.rendezvous) {
1414     Status s = CreateRendezvous(opts, &created_rendezvous);
1415     if (!s.ok()) {
1416       done(s);
1417       return;
1418     }
1419     new_opts.rendezvous = created_rendezvous;
1420     new_opts.create_rendezvous = false;
1421   }
1422 
1423   auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
1424   done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
1425                                     new_opts.step_id, created_rendezvous);
1426   std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
1427   done = [rets, function_rets, done = std::move(done)](const Status& s) {
1428     Status status = s;
1429     if (status.ok()) {
1430       for (const auto& ret : *function_rets) {
1431         if (ret.index() == 0) {
1432           rets->push_back(absl::get<Tensor>(ret));
1433         } else {
1434           status.Update(errors::Internal(
1435               "Expect a Tensor as a function output but got a TensorShape."));
1436           break;
1437         }
1438       }
1439     }
1440     delete function_rets;
1441     done(status);
1442   };
1443   bool multi_device;
1444   {
1445     tf_shared_lock l(mu_);
1446     multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
1447   }
1448   if (multi_device) {
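    // Maps the caller's flat argument list onto this component's arguments: a
    // plain index selects the corresponding input Tensor, while a non-negative
    // sub_index selects a single ResourceHandle out of a packed DT_RESOURCE
    // input.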
1449     auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1450                                       InternalArgs* comp_args) -> Status {
1451       // Indices of _Arg nodes are unique when all arguments are local Tensors.
1452       for (const auto& it : comp_data.arg_indices) {
1453         if (it.index >= args.size()) {
1454           return errors::InvalidArgument(
1455               "index ", it.index, " is out of range [0, ", args.size(), ")");
1456         }
1457         if (it.sub_index >= 0) {
1458           const Tensor& t = args[it.index];
1459           if (t.dtype() != DT_RESOURCE) {
1460             return errors::InvalidArgument("Got unexpected sub_index ",
1461                                            it.sub_index, " for argument ",
1462                                            it.index);
1463           }
1464           const auto& handles = t.flat<ResourceHandle>();
1465           if (it.sub_index >= handles.size()) {
1466             return errors::InvalidArgument(
1467                 "Sub_index ", it.sub_index, " is out of range [0, ",
1468                 handles.size(), ") for argument ", it.index);
1469           }
1470           comp_args->args.push_back(Tensor(handles(it.sub_index)));
1471         } else {
1472           comp_args->args.push_back(args[it.index]);
1473         }
1474       }
1475       return Status::OK();
1476     };
1477     return RunMultiDevice(new_opts, handle, function_rets, cleanup_items,
1478                           std::move(done), std::move(get_component_args));
1479   }
1480   std::vector<FunctionArg> local_args;
1481   for (const auto& tensor : args) {
1482     local_args.push_back(tensor);
1483   }
1484   RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
1485               std::move(done));
1486 }
1487 
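// Runs a non-multi-device handle. If the target device has a local FLR, the
// arguments are copied to it through the rendezvous under "arg_" keys, the
// function runs on that FLR, and the results are received back under "ret_"
// keys; otherwise the call is forwarded to the distributed runtime
// (`parent_`) and a CleanUpItem is recorded so the remote step state can be
// released later via CleanUp().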
1488 void ProcessFunctionLibraryRuntime::RunInternal(
1489     const FunctionLibraryRuntime::Options& opts,
1490     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
1491     std::vector<FunctionRet>* rets,
1492     std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
1493     FunctionLibraryRuntime::DoneCallback done) const {
1494   FunctionLibraryRuntime* flr = nullptr;
1495   string target_device;
1496   FunctionLibraryRuntime::LocalHandle local_handle;
1497   {
1498     tf_shared_lock l(mu_);
1499     auto iter = function_data_.find(handle);
1500     if (iter == function_data_.end()) {
1501       done(errors::NotFound("Handle: ", handle, " not found."));
1502       return;
1503     }
1504     FunctionData* function_data = iter->second.get();
1505     target_device = function_data->target_device();
1506     local_handle = function_data->local_handle();
1507   }
1508 
1509   if (!opts.remote_execution) {
1510     done(
1511         errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
1512                                 "only be called for multi-device functions or "
1513                                 "for remote execution."));
1514     return;
1515   }
1516 
1517   flr = GetFLR(target_device);
1518   if (flr != nullptr) {
1519     auto rendezvous = opts.rendezvous;
1520     string source_device = opts.source_device;
1521     DeviceContext* device_context;
1522     Status s = GetDeviceContext(source_device, &device_context);
1523     if (!s.ok()) {
1524       done(s);
1525       return;
1526     }
1527     int64_t src_incarnation, target_incarnation;
1528     s = GetDeviceIncarnation(source_device, &src_incarnation);
1529     s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
1530     if (!s.ok()) {
1531       done(s);
1532       return;
1533     }
1534 
1535     std::vector<Tensor> local_args = GetLocalArgs(args);
1536 
1537     // Send the args over to the target device.
1538     s = SendTensors(source_device, target_device, "arg_", src_incarnation,
1539                     local_args, device_context, opts.args_alloc_attrs,
1540                     rendezvous);
1541     if (!s.ok()) {
1542       done(s);
1543       return;
1544     }
1545     const std::vector<AllocatorAttributes>& rets_alloc_attrs =
1546         opts.rets_alloc_attrs;
1547     std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
1548     flr->Run(opts, handle, local_args, remote_rets,
1549              [source_device, target_device, target_incarnation, rendezvous,
1550               device_context, rets_alloc_attrs, remote_rets, rets,
1551               done = std::move(done)](const Status& status) mutable {
1552                if (!status.ok()) {
1553                  delete remote_rets;
1554                  done(status);
1555                  return;
1556                }
1557                int64_t num_returns = remote_rets->size();
1558                delete remote_rets;
1559                // Now receive the return values from the target.
1560                std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
1561                ReceiveTensorsAsync(target_device, source_device, "ret_",
1562                                    target_incarnation, num_returns,
1563                                    device_context, rets_alloc_attrs, rendezvous,
1564                                    recv_tensors,
1565                                    TensorsToFunctionRetsDoneCallback(
1566                                        rets, recv_tensors, std::move(done)));
1567              });
1568     return;
1569   }
1570   if (parent_ != nullptr) {
1571     auto cleanup_item = absl::make_unique<CleanUpItem>();
1572     cleanup_item->device = target_device;
1573     cleanup_item->step_id = opts.step_id;
1574     cleanup_item->local_handle = local_handle;
1575     cleanup_items->emplace_back(std::move(cleanup_item));
1576     parent_->Run(opts, local_handle, args, rets, std::move(done));
1577     return;
1578   }
1579   done(errors::Internal("Could not find device"));
1580 }
1581 
1582 void ProcessFunctionLibraryRuntime::Run(
1583     const FunctionLibraryRuntime::Options& opts,
1584     FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
1585     FunctionLibraryRuntime::DoneCallback done) const {
1586   std::vector<Tensor> args;
1587   args.reserve(frame->num_args());
1588   for (size_t i = 0; i < frame->num_args(); ++i) {
1589     const Tensor* arg;
1590     Status s = frame->GetArg(i, &arg);
1591     if (!s.ok()) {
1592       done(s);
1593       return;  // Abort: `arg` is not valid when GetArg() fails.
1594     }
1595     args.emplace_back(*arg);
  }
1596   std::vector<Tensor>* rets = new std::vector<Tensor>;
1597   rets->reserve(frame->num_retvals());
1598 
1599   Run(opts, handle, args, rets,
1600 
1601       [frame, rets, done = std::move(done)](const Status& status) {
1602         std::unique_ptr<std::vector<Tensor>> rets_releaser(rets);
1603 
1604         if (!status.ok()) {
1605           done(status);
1606           return;
1607         }
1608 
1609         if (rets->size() != frame->num_retvals()) {
1610           done(errors::Internal(
1611               "Number of return values from function (", rets->size(),
1612               ") did not match expected number of return values (",
1613               frame->num_retvals(), ")."));
1614           return;
1615         }
1616 
1617         for (size_t i = 0; i < frame->num_retvals(); ++i) {
1618           Status s = frame->SetRetval(i, (*rets)[i]);
1619           if (!s.ok()) {
1620             done(s);
1621             return;
1622           }
1623         }
1624         done(Status::OK());
1625       });
1626 }
1627 
1628 Status ProcessFunctionLibraryRuntime::RunSync(
1629     const FunctionLibraryRuntime::Options& opts,
1630     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
1631     std::vector<Tensor>* rets) const {
1632   Notification n;
1633   Status s;
1634   Run(opts, handle, args, rets, [&n, &s](const Status& status) {
1635     s.Update(status);
1636     n.Notify();
1637   });
1638   n.WaitForNotification();
1639   return s;
1640 }
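
// Example use of the ArraySlice-based RunSync above (informal sketch; `pflr`,
// `handle`, and `input` are placeholders):
//   FunctionLibraryRuntime::Options run_opts;
//   std::vector<Tensor> outputs;
//   Status status = pflr->RunSync(run_opts, handle, {input}, &outputs);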
1641 
1642 Status ProcessFunctionLibraryRuntime::RunSync(
1643     const FunctionLibraryRuntime::Options& opts,
1644     FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
1645   Notification n;
1646   Status s;
1647   Run(opts, handle, frame, [&n, &s](const Status& status) {
1648     s.Update(status);
1649     n.Notify();
1650   });
1651   n.WaitForNotification();
1652   return s;
1653 }
1654 
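// Overload used when arguments may be remote or packed (as in eager
// multi-device execution). Purely local calls are forwarded to the
// Tensor-based Run() above; otherwise the multi-device path is taken and
// remote arguments are forwarded as eager::RemoteTensorHandle protos.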
1655 void ProcessFunctionLibraryRuntime::Run(
1656     const FunctionLibraryRuntime::Options& opts,
1657     FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
1658     std::vector<FunctionRet>* rets,
1659     FunctionLibraryRuntime::DoneCallback done) const {
1660   bool has_remote_outputs = false;
1661   const MultiDeviceFunctionData* data = IsMultiDevice(handle);
1662   if (data != nullptr) {
1663     has_remote_outputs = data->has_remote_outputs;
1664   }
1665   if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
1666     const std::vector<Tensor> local_inputs = args.GetLocalTensors();
1667     std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
1668     return Run(
1669         opts, handle, local_inputs, tensor_rets,
1670         TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done)));
1671   }
1672 
1673   FunctionLibraryRuntime::Options new_opts = opts;
1674   Rendezvous* created_rendezvous = nullptr;
1675   if (!opts.rendezvous) {
1676     Status s = CreateRendezvous(opts, &created_rendezvous);
1677     if (!s.ok()) {
1678       done(s);
1679       return;
1680     }
1681     new_opts.rendezvous = created_rendezvous;
1682     new_opts.create_rendezvous = false;
1683   }
1684 
1685 #if defined(IS_MOBILE_PLATFORM)
1686   done(errors::Unimplemented(
1687       "Remote inputs are not available on mobile devices."));
1688   return;
1689 #else   // !IS_MOBILE_PLATFORM
1690   auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
1691   done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
1692                                     created_rendezvous);
1693 
1694   auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1695                                     InternalArgs* comp_args) -> Status {
1696     for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
1697       const FunctionArgIndex index = comp_data.arg_indices.at(i);
1698       Tensor tensor;
1699       if (args.GetLocalArg(index, &tensor).ok()) {
1700         comp_args->args.push_back(std::move(tensor));
1701       } else {
1702         eager::RemoteTensorHandle remote_handle;
1703         TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
1704         comp_args->remote_args.emplace_back(
1705             absl::make_unique<eager::RemoteTensorHandle>(
1706                 std::move(remote_handle)));
1707         comp_args->args.push_back(comp_args->remote_args.back().get());
1708       }
1709     }
1710     return Status::OK();
1711   };
1712   return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done),
1713                         std::move(get_component_args));
1714 #endif  // !IS_MOBILE_PLATFORM
1715 }
1716 
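// Releases per-step state for the given cleanup items. Items are expected to
// refer to executions owned by the distributed runtime (`parent_`); an item
// that resolves to a local FLR is reported as an internal error.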
1717 void ProcessFunctionLibraryRuntime::CleanUp(
1718     std::vector<std::unique_ptr<CleanUpItem>>* items,
1719     FunctionLibraryRuntime::DoneCallback done) const {
1720   auto* refcounted_done = new ReffedStatusCallback(std::move(done));
1721   for (auto& item : *items) {
1722     refcounted_done->Ref();
1723     auto* flr = GetFLR(item->device);
1724     if (flr != nullptr) {
1725       // TODO(fishx): cleanup state for local execution.
1726       refcounted_done->UpdateStatus(
1727           errors::Internal("Cleanup items shouldn't contain local item."));
1728       refcounted_done->Unref();
1729     } else if (parent_ != nullptr) {
1730       parent_->CleanUp(item->step_id, item->local_handle,
1731                        [refcounted_done](const Status& status) {
1732                          if (!status.ok()) {
1733                            refcounted_done->UpdateStatus(status);
1734                          }
1735                          // refcounted_done is thread-safe
1736                          refcounted_done->Unref();
1737                        });
1738     } else {
1739       refcounted_done->UpdateStatus(
1740           errors::Internal("Could not find device in cleanup."));
1741       refcounted_done->Unref();
1742     }
1743   }
1744   refcounted_done->Unref();
1745 }
1746 
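// Creates a new ProcessFunctionLibraryRuntime that shares this one's
// DeviceMgr, parent distributed runtime, and composite devices, but owns its
// own FunctionLibraryDefinition: a copy of the current library, or an empty
// one when `skip_flib_def` is true.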
1747 Status ProcessFunctionLibraryRuntime::Clone(
1748     Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
1749     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1750     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1751     bool skip_flib_def) const {
1752   if (skip_flib_def) {
1753     *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(
1754         lib_def_->default_registry(), FunctionDefLibrary{});
1755   } else {
1756     *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_);
1757   }
1758   *out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
1759       device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
1760       out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
1761       session_metadata_, rendezvous_factory_);
1762   {
1763     tf_shared_lock l(mu_);
1764     for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
1765   }
1766   return Status::OK();
1767 }
1768 
1769 }  // namespace tensorflow
1770