• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/execute.h"
17 
18 #include <vector>
19 
20 #include "absl/strings/match.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_set.h"
23 #include "tensorflow/core/common_runtime/eager/context.h"
24 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
25 #include "tensorflow/core/common_runtime/eager/execute_node.h"
26 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
27 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #ifndef __ANDROID__
33 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
34 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
35 #endif
36 #include "tensorflow/core/framework/step_stats.pb.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/lib/gtl/flatset.h"
41 #include "tensorflow/core/lib/gtl/inlined_vector.h"
42 #include "tensorflow/core/lib/random/random.h"
43 #include "tensorflow/core/platform/env.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/util/ptr_util.h"
46 
47 namespace tensorflow {
48 
49 namespace {
50 
// Name of the node attribute that requests XLA compilation of a function.
// This mirrors the definition in third_party/tensorflow/compiler/jit/defs.h;
// it is duplicated here because XLA is not currently compiled on Windows, so
// this file cannot depend on that header directly.
constexpr char kXlaCompileAttr[] = "_XlaCompile";
55 
56 // Initializes the step stats if needed.
MaybeInitializeStepStats(StepStats * step_stats,EagerContext * ctx)57 void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
58   // Lazily initialize the RunMetadata with information about all devices if
59   // this is the first call.
60   while (step_stats->dev_stats_size() < ctx->devices()->size()) {
61     int device_idx = step_stats->dev_stats_size();
62     auto* dev_stats = step_stats->add_dev_stats();
63     dev_stats->set_device(ctx->devices()->at(device_idx)->name());
64   }
65 }
66 
StepStatsDeviceIndex(StepStats * step_stats,EagerContext * ctx,Device * device)67 int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
68                          Device* device) {
69   // Find the current device's index.
70   if (device == nullptr) {
71     device = ctx->HostCPU();
72   }
73   for (int i = 0; i < ctx->devices()->size(); ++i) {
74     if (ctx->devices()->at(i) == device ||
75         ctx->devices()->at(i)->name() == device->name()) {
76       return i;
77     }
78   }
79   // TODO(apassos) do not fall back to host CPU if device is unknown.
80   return 0;
81 }
82 
83 // This function expects *handle to point to an existing tensor handle. The
84 // function will (maybe) update the *handle to be pointed to the newly copied
85 // tensor handle.
86 //
87 // The passed in *handle will be Unreffed if it is replaced.
88 //
89 // `op_device_name` is passed in explicitly because `op->device()` might be
90 // unset and we might have selected some specific device to run this op on.
MaybeCopyInputToExpectedDevice(EagerOperation * op,const string & op_device_name,int i,const Device * expected_input_device,RunMetadata * run_metadata,TensorHandle ** handle)91 Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
92                                       const string& op_device_name, int i,
93                                       const Device* expected_input_device,
94                                       RunMetadata* run_metadata,
95                                       TensorHandle** handle) {
96   EagerContext* ctx = op->EagerContext();
97   Device* handle_device = (*handle)->device();
98   const Device* actual_device =
99       handle_device == nullptr ? ctx->HostCPU() : handle_device;
100 
101   if (expected_input_device != actual_device) {
102     switch (ctx->GetDevicePlacementPolicy()) {
103       case DEVICE_PLACEMENT_SILENT_FOR_INT32:
104         // TODO(xpan): See if we could bubble python related error up
105         // to python level.
106         if ((*handle)->dtype == DT_INT32) {
107           // Note: enabling silent copies of int32 tensors to match behavior
108           // of graph mode.
109           break;
110         }
111         TF_FALLTHROUGH_INTENDED;
112       case DEVICE_PLACEMENT_EXPLICIT:
113         return errors::InvalidArgument(
114             "Tensors on conflicting devices:"
115             " cannot compute ",
116             op->Name(), " as input #", i, " was expected to be on ",
117             expected_input_device->name(), " but is actually on ",
118             actual_device->name(), " (operation running on ", op_device_name,
119             ")",
120             " Tensors can be copied explicitly using .gpu() or .cpu() "
121             "methods,"
122             " or transparently copied by using tf.enable_eager_execution("
123             "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
124             "between devices"
125             " may slow down your model");
126       case DEVICE_PLACEMENT_WARN:
127         LOG(WARNING) << "before computing " << op->Name() << " input #" << i
128                      << " was expected to be on "
129                      << expected_input_device->name() << " but is actually on "
130                      << actual_device->name() << " (operation running on "
131                      << op_device_name
132                      << "). This triggers a copy which can be a performance "
133                         "bottleneck.";
134         break;
135       case DEVICE_PLACEMENT_SILENT:  // Do nothing.
136         break;
137     }
138     // We are only here if the policy is warn or silent copies, so we should
139     // trigger a copy.
140     auto pre_time_nanos = Env::Default()->NowNanos();
141     TensorHandle* result_handle = nullptr;
142     Status status = EagerCopyToDevice(
143         *handle, ctx, expected_input_device->name().c_str(), &result_handle);
144     if (run_metadata != nullptr) {
145       auto* step_stats = run_metadata->mutable_step_stats();
146       MaybeInitializeStepStats(step_stats, ctx);
147       // Record the sending on the source device for now.
148       int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
149       auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
150       auto* node_stats = dev_stats->add_node_stats();
151       node_stats->set_node_name("_Send");
152       node_stats->set_all_start_micros(pre_time_nanos /
153                                        EnvTime::kMicrosToNanos);
154       node_stats->set_all_start_nanos(pre_time_nanos);
155       int64 now_nanos = Env::Default()->NowNanos();
156       node_stats->set_op_end_rel_micros((now_nanos - pre_time_nanos) /
157                                         EnvTime::kMicrosToNanos);
158       node_stats->set_op_end_rel_nanos(now_nanos - pre_time_nanos);
159       node_stats->set_all_end_rel_micros((now_nanos - pre_time_nanos) /
160                                          EnvTime::kMicrosToNanos);
161       node_stats->set_all_end_rel_nanos(now_nanos - pre_time_nanos);
162     }
163     if (!status.ok()) {
164       if (result_handle != nullptr) result_handle->Unref();
165       return errors::Internal(
166           "Failed copying input tensor from ", actual_device->name(), " to ",
167           expected_input_device->name(), " in order to run ", op->Name(), ": ",
168           status.error_message());
169     }
170 
171     (*handle)->Unref();
172     *handle = result_handle;
173   }
174   return Status::OK();
175 }
176 
177 // `op_device_name` the name of the device on which the op will run, if any.
178 // For functions running using function library runtime, the device can be
179 // unspecified.
ValidateInputTypeAndPlacement(EagerContext * ctx,const string & op_device_name,EagerOperation * op,const KernelAndDevice * kernel,RunMetadata * run_metadata)180 Status ValidateInputTypeAndPlacement(EagerContext* ctx,
181                                      const string& op_device_name,
182                                      EagerOperation* op,
183                                      const KernelAndDevice* kernel,
184                                      RunMetadata* run_metadata) {
185   if (kernel->num_inputs() != op->Inputs().size()) {
186     return errors::InvalidArgument("expected ", kernel->num_inputs(),
187                                    " inputs, got ", op->Inputs().size());
188   }
189   for (int i = 0; i < op->Inputs().size(); ++i) {
190     const Device* expected_device = kernel->InputDevice(i);
191     TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
192         op, op_device_name, i, expected_device, run_metadata,
193         &((*op->MutableInputs())[i])));
194     tensorflow::TensorHandle* handle = op->Inputs()[i];
195     if (handle->dtype != kernel->input_type(i)) {
196       return errors::InvalidArgument(
197           "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
198           " was expected to be a ", DataTypeString(kernel->input_type(i)),
199           " tensor but is a ", DataTypeString(handle->dtype), " tensor");
200     }
201   }
202   return Status::OK();
203 }
204 
SelectDevice(const NodeDef & ndef,EagerContext * ctx,Device ** device)205 Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
206   PrioritizedDeviceTypeVector final_devices;
207   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
208       ctx->prioritized_device_type_list(), ndef, &final_devices));
209   if (final_devices.empty()) {
210     return errors::Internal("Could not find valid device for node.\nNode: ",
211                             FormatNodeDefForError(ndef),
212                             "\nAll kernels registered for op ", ndef.op(),
213                             " :\n", KernelsRegisteredForOp(ndef.op()));
214   }
215   for (Device* d : *ctx->devices()) {
216     if (d->device_type() == final_devices[0].first.type_string()) {
217       *device = d;
218       return Status::OK();
219     }
220   }
221   return errors::Unknown("Could not find a device for node ",
222                          FormatNodeDefForError(ndef));
223 }
224 
GetOutputDTypes(EagerOperation * op,DataTypeVector * output_dtypes)225 Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
226   const auto& node_def = op->MutableAttrs()->BuildNodeDef();
227   const OpDef* op_def = nullptr;
228 
229   const FunctionDef* function_def =
230       op->EagerContext()->FuncLibDef()->Find(op->Name());
231   if (function_def != nullptr) {
232     op_def = &(function_def->signature());
233   } else {
234     TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
235   }
236 
237   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
238 
239   return Status::OK();
240 }
241 
242 }  // namespace
243 
244 namespace {
IsLocal(EagerContext * ctx,tensorflow::Device * d)245 bool IsLocal(EagerContext* ctx, tensorflow::Device* d) {
246   if (d == nullptr || ctx->remote_device_mgr() == nullptr) return true;
247   tensorflow::Device* tmp;
248   return ctx->local_device_mgr()->LookupDevice(d->name(), &tmp).ok();
249 }
250 
OnSameTask(EagerContext * ctx,Device * first,Device * second)251 bool OnSameTask(EagerContext* ctx, Device* first, Device* second) {
252   if (first == nullptr) first = ctx->HostCPU();
253   if (second == nullptr) second = ctx->HostCPU();
254   return first->parsed_name().job == second->parsed_name().job &&
255          first->parsed_name().replica == second->parsed_name().replica &&
256          first->parsed_name().task == second->parsed_name().task;
257 }
258 
259 // Gets the CPU device on the task of device.
CPUDeviceOnTask(EagerContext * ctx,tensorflow::Device * device,tensorflow::Device ** cpu_device)260 Status CPUDeviceOnTask(EagerContext* ctx, tensorflow::Device* device,
261                        tensorflow::Device** cpu_device) {
262   string cpu_device_name;
263   TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
264       device->name(), &cpu_device_name));
265 
266   return ctx->FindDeviceByName(cpu_device_name, cpu_device);
267 }
268 
// Combines two 128-bit fingerprints by mixing their low halves and their high
// halves independently with FingerprintCat64.
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
  const auto low = tensorflow::FingerprintCat64(a.low64, b.low64);
  const auto high = tensorflow::FingerprintCat64(a.high64, b.high64);
  return {low, high};
}
274 
FindDeviceFromName(const EagerContext * ctx,const char * device_name,Device ** device)275 Status FindDeviceFromName(const EagerContext* ctx, const char* device_name,
276                           Device** device) {
277   *device = ctx->HostCPU();
278   if (device_name == nullptr || strlen(device_name) == 0) {
279     return Status::OK();
280   }
281 
282   auto status = ctx->local_device_mgr()->LookupDevice(device_name, device);
283   if (status.ok()) {
284     return status;
285   }
286 
287   if (ctx->remote_device_mgr() != nullptr) {
288     return ctx->remote_device_mgr()->LookupDevice(device_name, device);
289   }
290 
291   return status;
292 }
293 
// Decides whether an operation should be executed as a multi-device function.
// A null `fdef` means a primitive op, which never is. Currently every function
// is run as multi-device, and `op_device` is not consulted.
//
// We can eliminate some overhead by running simple functions using regular
// CallOp kernel. However, it is tricky to figure out which functions should
// be run using CallOp. Also, currently CallOp runs neither optimization
// passes (needed for TPU/XLA) nor grappler.
// Here are some cases where a function should be run in multi-device mode:
//  - Function takes at least two resources on different devices.
//  - Function takes a resource on deviceA and a body op explicitly placed
//  on deviceB.
//  - Function has a colocation constraint.
//  - Function has an explicit device annotation (which might not be using
//    full canonical device name) different from op_device. Note that false
//    positives are ok.
//  - Function has a node or a (node) attribute that can potentially make
//    the function multi-device after a rewrite pass (e.g. various XLA/TPU
//    special nodes and attributes)
bool IsMultiDevice(const FunctionDef* fdef, const string& op_device) {
  return fdef != nullptr;
}
319 
AddInputDevicesToCacheKey(const EagerContext * ctx,const EagerOperation * op,std::vector<Device * > * input_dev_ptrs,Fprint128 * cache_key)320 Status AddInputDevicesToCacheKey(const EagerContext* ctx,
321                                  const EagerOperation* op,
322                                  std::vector<Device*>* input_dev_ptrs,
323                                  Fprint128* cache_key) {
324   input_dev_ptrs->reserve(op->Inputs().size());
325   Device* cpu_device = ctx->HostCPU();
326   for (TensorHandle* tensor_handle : op->Inputs()) {
327     string device_name;
328     if (tensor_handle->dtype == DT_RESOURCE) {
329       // Use the resource's actual device because it is the device that will
330       // influence partitioning the multi-device function.
331       const Tensor* tensor;
332       TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
333       const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
334       device_name = handle.device();
335 
336       Device* input_device;
337       TF_RETURN_IF_ERROR(
338           FindDeviceFromName(ctx, device_name.c_str(), &input_device));
339       input_dev_ptrs->push_back(input_device);
340     } else if (MTypeFromDType(tensor_handle->dtype) == HOST_MEMORY) {
341       input_dev_ptrs->push_back(cpu_device);
342     } else {
343       Device* device = tensor_handle->device();
344       device_name = device != nullptr ? device->name() : cpu_device->name();
345       input_dev_ptrs->push_back(device == nullptr ? cpu_device : device);
346     }
347     *cache_key = FingerprintCat128(*cache_key, Fingerprint128(device_name));
348   }
349   return Status::OK();
350 }
351 
352 // There are a lot of references to devices in this function and around.
353 // Here is what they mean:
354 //  EagerOperation::Device(): The device on which the user requested the op
355 //    be executed, except if we had to change the device due to resource inputs
356 //    or CPU pinning. If the user did not request a device, the op does not
357 //    take resources, and we did not pin it to CPU, the device can be nullptr.
358 //  KernelAndDevice::Device(): The first time we see an op (combined with
359 //    its attributes), we need to create a KernelAndDevice object for it.
360 //    If op->Device() is a nullptr, we select a device for the op when
361 //    creating the KernelAndDevice. A concrete device will always be selected
362 //    here except when `op` is a function to be executed using function library
363 //    runtime. In this case, we don't select a device because running
364 //    a function with explicitly requested device has different behavior than
365 //    running without an explicitly requested device.
EagerLocalExecute(EagerOperation * op,gtl::InlinedVector<TensorHandle *,2> * retvals,int * num_retvals)366 Status EagerLocalExecute(EagerOperation* op,
367                          gtl::InlinedVector<TensorHandle*, 2>* retvals,
368                          int* num_retvals) {
369   const string unspecified_device_name("<unspecified>");
370   EagerContext* ctx = op->EagerContext();
371   auto status = ctx->GetStatus();
372   if (!status.ok()) return status;
373   Device* device = op->Device();
374 
375   const string& maybe_unspecified_device_name =
376       device == nullptr ? unspecified_device_name : device->name();
377   Fprint128 cache_key =
378       op->MutableAttrs()->CacheKey(maybe_unspecified_device_name);
379 
380   bool is_multi_device_function = IsMultiDevice(
381       ctx->FindFunctionDef(op->Name()), maybe_unspecified_device_name);
382 
383   std::vector<Device*> input_dev_ptrs;
384   if (is_multi_device_function) {
385     TF_RETURN_IF_ERROR(
386         AddInputDevicesToCacheKey(ctx, op, &input_dev_ptrs, &cache_key));
387   }
388 
389   KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
390   if (kernel == nullptr) {
391     VLOG(2) << "Creating new kernel for " << op->Name() << " on device "
392             << maybe_unspecified_device_name;
393     // If we are running a function on explicitly requested TPU,
394     // compile it with XLA.
395     // Note that it is not ideal, but currently ok, to set this
396     // attribute after computing the kernel cache key above.
397     bool compile_with_xla = false;
398     if (op->is_function() && device != nullptr &&
399         (device->device_type() == "TPU" || device->device_type() == "XLA_GPU" ||
400          device->device_type() == "XLA_CPU")) {
401       op->MutableAttrs()->Set(kXlaCompileAttr, true);
402       compile_with_xla = true;
403     }
404     bool run_function_with_flr = is_multi_device_function && !compile_with_xla;
405 
406     const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
407     if (!run_function_with_flr && device == nullptr) {
408       status = SelectDevice(ndef, ctx, &device);
409       if (!status.ok()) return status;
410     }
411     const string& device_name =
412         device == nullptr ? unspecified_device_name : device->name();
413     if (ctx->LogDevicePlacement()) {
414       LOG(INFO) << "Executing op " << ndef.op() << " in device " << device_name;
415     } else {
416       VLOG(1) << "Executing op " << ndef.op() << " in device " << device_name;
417     }
418 
419     FunctionLibraryRuntime* flr =
420         device == nullptr ? nullptr : ctx->func_lib(device);
421     if (device != nullptr && flr == nullptr) {
422       return errors::Unavailable(
423           "Unable to find a FunctionLibraryRuntime corresponding to device ",
424           device->name());
425     }
426     auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
427                                                                : ctx->runner();
428     GraphCollector* graph_collector = nullptr;
429     if (ctx->ShouldStoreGraphs()) {
430       graph_collector = ctx->GetGraphCollector();
431     }
432     // Treat the function as multi_device only when we are not compiling
433     // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
434     // will create an XlaLaunchOp kernel to compile and run the function.
435     if (run_function_with_flr) {
436       // Multi-device functions don't use the rendezvous from eager context.
437       // If we use that rendezvous, multiple concurrent calls to the same
438       // function will likely result in collisions. However, this also means
439       // that we don't support legitimate sending/receiving across function
440       // boundary.
441       VLOG(2) << "Running " << ndef.op() << " using multi-device function. "
442               << "compile_with_xla=" << compile_with_xla
443               << ". Full node_def=" << ndef.DebugString();
444       kernel = new KernelAndDeviceFunc(
445           flr, ctx->pflr(), std::move(input_dev_ptrs), runner,
446           ctx->GetCollectiveExecutorHandle(), ctx->HostCPU());
447     } else {
448       VLOG(2) << "Running " << ndef.op() << " using op kernel. "
449               << "compile_with_xla=" << compile_with_xla
450               << ". Full node_def=" << ndef.DebugString();
451       kernel = new KernelAndDeviceOp(
452           ctx->GetRendezvous(), ctx->LogMemory(), flr, runner,
453           ctx->GetCollectiveExecutorHandle(), ctx->HostCPU());
454     }
455 
456     status = kernel->Init(ndef, graph_collector);
457     if (!status.ok()) {
458       delete kernel;
459       return status;
460     }
461 
462     ctx->AddKernelToCache(cache_key, kernel);
463   }
464   const DataTypeVector& output_dtypes = kernel->output_dtypes();
465   const int output_dtypes_size = static_cast<int>(output_dtypes.size());
466   if (output_dtypes_size > *num_retvals) {
467     return errors::InvalidArgument("Expecting ", output_dtypes.size(),
468                                    " outputs, but *num_retvals is ",
469                                    *num_retvals);
470   }
471   *num_retvals = output_dtypes_size;
472   const string& device_name = kernel->device() == nullptr
473                                   ? unspecified_device_name
474                                   : kernel->device()->name();
475   status = ValidateInputTypeAndPlacement(
476       ctx, device_name, op, kernel,
477       ctx->ShouldStoreStepStats() ? ctx->RunMetadataProto() : nullptr);
478   if (!status.ok()) return status;
479   std::unique_ptr<NodeExecStats> maybe_stats;
480   StepStats* maybe_step_stats = nullptr;
481   GraphCollector* graph_collector = nullptr;
482   if (ctx->ShouldStoreGraphs()) {
483     graph_collector = ctx->GetGraphCollector();
484   }
485   if (ctx->ShouldStoreStepStats()) {
486     maybe_step_stats = ctx->RunMetadataProto()->mutable_step_stats();
487     int64 now_nanos = Env::Default()->NowNanos();
488     maybe_stats.reset(new NodeExecStats);
489     maybe_stats->set_node_name(op->Name());
490     maybe_stats->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
491     maybe_stats->set_all_start_nanos(now_nanos);
492     maybe_stats->set_op_start_rel_micros(0);
493     maybe_stats->set_op_start_rel_nanos(0);
494     maybe_stats->set_scheduled_micros(now_nanos / EnvTime::kMicrosToNanos);
495     maybe_stats->set_scheduled_nanos(now_nanos);
496     // TODO(apassos) track referenced tensors
497   }
498   retvals->resize(*num_retvals);
499   if (ctx->Async()) {
500     // Note that for async mode, execution order will make sure that all
501     // input handles are ready before executing them.
502     // TODO(agarwal): Consider executing "cheap" kernels inline for
503     // performance.
504     tensorflow::uint64 id = ctx->NextId();
505     for (int i = 0; i < *num_retvals; ++i) {
506       (*retvals)[i] = new TensorHandle(
507           id, /* d= */ kernel->OutputDevice(i),
508           /* op_device= */ kernel->device(),
509           /* resource_device= */ kernel->OutputResourceDevice(i),
510           output_dtypes[i], ctx);
511     }
512     EagerNode* node = new ExecuteNode(id, ctx, op->Inputs(), kernel,
513                                       maybe_stats.release(), maybe_step_stats,
514                                       graph_collector, output_dtypes, *retvals);
515     ctx->ExecutorAdd(node);
516   } else {
517     // Execute checks if retvals[i] is nullptr or not to figure if it needs to
518     // allocate it.
519     status = EagerKernelExecute(ctx, op->Inputs(), kernel, maybe_stats.get(),
520                                 maybe_step_stats, graph_collector,
521                                 retvals->data(), *num_retvals);
522   }
523 
524   return status;
525 }
526 
527 #ifndef __ANDROID__
GetRemoteTensorDestructor(EagerContext * ctx,eager::EagerClient * eager_client,uint64 context_id,uint64 op_id,int output_num)528 std::function<void()> GetRemoteTensorDestructor(
529     EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
530     uint64 op_id, int output_num) {
531   return [ctx, eager_client, context_id, op_id, output_num]() {
532     if (!ctx->HasActiveRemoteContext(context_id)) {
533       // This means that this tensor was pointing to a remote device, which
534       // has been changed out from under us. Simply return since there is
535       // nothing we can do.
536       return tensorflow::Status::OK();
537     }
538 
539     std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
540     request->set_context_id(context_id);
541 
542     auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
543     handle_to_decref->set_op_id(op_id);
544     handle_to_decref->set_output_num(output_num);
545 
546     if (ctx->Async()) {
547       tensorflow::uint64 id = ctx->NextId();
548       auto* node =
549           new eager::RemoteExecuteNode(id, std::move(request), eager_client);
550       ctx->ExecutorAdd(node);
551     } else {
552       eager::EnqueueRequest* actual_request = request.release();
553       eager::EnqueueResponse* response = new eager::EnqueueResponse;
554       eager_client->EnqueueAsync(
555           actual_request, response,
556           [actual_request, response](const tensorflow::Status& s) {
557             delete actual_request;
558             delete response;
559           });
560     }
561 
562     return tensorflow::Status::OK();
563   };
564 }
565 #endif
566 
567 // When !ctx->UseSendTensorRPC(), then tensors are shipped between remote
568 // devices by the receiver invoking the WorkerService.RecvTensor RPC *on the
569 // sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel).
570 //
571 // However, in some configurations the node that has the tensor to be copied
572 // isn't running a server (WorkerService RPC interface). For such cases,
573 // this function enables sending tensors using the EagerService.SendTensor RPC
574 // *on the receiver*.
EagerRemoteSendTensor(EagerContext * ctx,TensorHandle * h,Device * recv_device,TensorHandle ** result)575 Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
576                              Device* recv_device, TensorHandle** result) {
577 #ifdef __ANDROID__
578   return errors::Unimplemented(
579       "Eager's remote execution is not available on Android devices.");
580 #else
581   eager::EagerClient* eager_client;
582   uint64 context_id;
583   TF_RETURN_IF_ERROR(
584       ctx->GetClientAndContextID(recv_device, &eager_client, &context_id));
585 
586   eager::SendTensorRequest request;
587   eager::SendTensorResponse response;
588 
589   request.set_context_id(context_id);
590   request.set_op_id(ctx->NextId());
591   request.set_device_name(recv_device->name());
592 
593   Device* tensor_handle_device = h->device();
594 
595   // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
596   // copy it to the CPU before copying it out.
597   // TODO(nareshmodi): this is currently slow, but can be fixed by making
598   // tensor handles aware of more than one device.
599   TensorHandle* actual_handle;
600   if (tensor_handle_device != nullptr &&
601       tensor_handle_device->device_type() != "CPU") {
602     TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, ctx->HostCPU(), &actual_handle));
603   } else {
604     actual_handle = h;
605     actual_handle->Ref();
606   }
607 
608   const Tensor* tensor;
609   TF_RETURN_IF_ERROR(actual_handle->Tensor(&tensor));
610   tensor->AsProtoTensorContent(request.add_tensors());
611 
612   const tensorflow::uint64 id = request.op_id();
613 
614   // TODO(nareshmodi): support making this call async.
615   Notification n;
616   Status status;
617   eager_client->SendTensorAsync(&request, &response,
618                                 [&n, &status](const Status& s) {
619                                   status = s;
620                                   n.Notify();
621                                 });
622   n.WaitForNotification();
623   if (!status.ok()) return status;
624 
625   std::function<void()> destructor =
626       GetRemoteTensorDestructor(ctx, eager_client, context_id, id, 0);
627 
628   *result = new TensorHandle(id, /*output_num=*/0, /*remote_shape_node_id=*/0,
629                              tensor->dtype(), std::move(destructor),
630                              /*d=*/recv_device, /*op_device=*/recv_device,
631                              /*resource_device=*/nullptr, ctx);
632   (*result)->SetRemoteShape(MakeUnique<TensorShape>(tensor->shape()));
633 
634   actual_handle->Unref();
635 
636   return Status::OK();
637 #endif
638 }
639 
EagerRemoteExecute(EagerOperation * op,TensorHandle ** retvals,int * num_retvals)640 Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
641                           int* num_retvals) {
642 #ifdef __ANDROID__
643   return errors::Unimplemented(
644       "Eager's remote execution is not available on Android devices.");
645 #else
646   EagerContext* ctx = op->EagerContext();
647 
648   eager::EagerClient* eager_client;
649   uint64 context_id;
650   TF_RETURN_IF_ERROR(
651       ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
652 
653   std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
654   eager::EnqueueResponse response;
655 
656   request->set_context_id(context_id);
657 
658   auto* remote_op = request->add_queue()->mutable_operation();
659 
660   for (int i = 0; i < op->Inputs().size(); i++) {
661     tensorflow::Device* input_device = op->Inputs()[i]->device();
662     if (op->Device() != input_device &&
663         // If the expected and actual devices are on the same task, don't
664         // explicitly copy, and instead depend on the copy to happen locally
665         // when the op is executed on the device.
666         !OnSameTask(ctx, op->Device(), input_device)) {
667       tensorflow::Device* remote_cpu_device;
668       TF_RETURN_IF_ERROR(
669           CPUDeviceOnTask(ctx, op->Device(), &remote_cpu_device));
670       // TODO(b/110044833): It's possible the same tensor gets copied to the
671       // remote device repeatedly.
672       // Always copy to the remote CPU so that the actual device can be
673       // correctly determined after the kernel is selected/instantiated, since
674       // the op might have its inputs on host memory.
675       TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
676           op, op->Device()->name(), i, remote_cpu_device,
677           /* run_metadata= */ nullptr, &(*op->MutableInputs())[i]));
678     }
679 
680     tensorflow::TensorHandle* input = op->Inputs()[i];
681 
682     tensorflow::int64 op_id;
683     int32 output_num;
684     TF_RETURN_IF_ERROR(input->RemoteAddress(&op_id, &output_num));
685 
686     auto* remote_op_input = remote_op->add_inputs();
687     remote_op_input->set_op_id(op_id);
688     remote_op_input->set_output_num(output_num);
689   }
690 
691   remote_op->set_id(op->EagerContext()->NextId());
692   remote_op->set_name(op->Name());
693   // Inputs set above.
694   op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
695   remote_op->set_device(op->Device()->name());
696 
697   DataTypeVector output_dtypes;
698   TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
699 
700   if (*num_retvals != output_dtypes.size()) {
701     return errors::InvalidArgument(
702         "num_retvals does not match expected output dtypes");
703   }
704 
705   tensorflow::Device* op_device = op->Device();
706 
707   bool is_async = op->EagerContext()->Async();
708   uint64 remote_node_id = 0;
709 
710   if (is_async) {
711     remote_node_id = op->EagerContext()->NextId();
712   }
713 
714   const tensorflow::uint64 id = remote_op->id();
715   for (int i = 0; i < *num_retvals; i++) {
716     // TODO(nareshmodi): Change the callback to instead add the decref to a
717     // list of pending decrefs that we can send as a batch with the next
718     // execute.
719     std::function<void()> destructor =
720         GetRemoteTensorDestructor(ctx, eager_client, context_id, id, i);
721 
722     // The device_ and resource_device_ or this TensorHandle are not correct.
723     // It is pretty hard to make it correct because for multi-device functions,
724     // we don't know the output device until the function is instantiated.
725     // Luckily, we don't need to know the correct remote device here. We just
726     // need to know that it is remote. If we need to copy this tensor to this
727     // process, the remote end will know the correct device of this handle.
728     retvals[i] = new TensorHandle(
729         remote_op->id(), i, remote_node_id, output_dtypes[i],
730         std::move(destructor),
731         /*d=*/op_device, /*op_device=*/op_device,
732         /*resource_device=*/output_dtypes[i] == DT_RESOURCE ? op_device
733                                                             : nullptr,
734         op->EagerContext());
735   }
736 
737   if (is_async) {
738     // Copy the output handles, since the container for them might get
739     // destroyed.
740     gtl::InlinedVector<TensorHandle*, 2> retvals_copy;
741     for (int i = 0; i < *num_retvals; i++) {
742       retvals_copy.push_back(retvals[i]);
743       retvals_copy[i]->Ref();
744     }
745     // Unable to capture via std::move, so bind instead.
746     auto* node = new eager::RemoteExecuteNode(
747         remote_node_id, std::move(request), eager_client, op->Inputs(),
748         std::bind(
749             [](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
750                const Status& status, const eager::EnqueueResponse& response) {
751               if (!status.ok()) return;
752               for (int i = 0; i < retvals.size(); i++) {
753                 retvals[i]->SetRemoteShape(MakeUnique<TensorShape>(
754                     response.queue_response(0).shape(i)));
755                 retvals[i]->Unref();
756               }
757             },
758             std::move(retvals_copy), std::placeholders::_1,
759             std::placeholders::_2));
760     op->EagerContext()->ExecutorAdd(node);
761   } else {
762     Notification n;
763     Status status;
764     eager_client->EnqueueAsync(request.get(), &response,
765                                [&n, &status](const Status& s) {
766                                  status = s;
767                                  n.Notify();
768                                });
769     n.WaitForNotification();
770 
771     if (!status.ok()) return status;
772 
773     for (int i = 0; i < *num_retvals; i++) {
774       retvals[i]->SetRemoteShape(
775           MakeUnique<TensorShape>(response.queue_response(0).shape(i)));
776     }
777   }
778 
779   return Status::OK();
780 #endif
781 }
782 
783 // These ops are not pinnable since they generate data. It can be slower to
784 // generate and then copy the data instead of just generating the data on the
785 // device directly.
IsPinnableOp(const string & op_type)786 bool IsPinnableOp(const string& op_type) {
787   static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
788       "RandomUniform",
789       "RandomUniformInt",
790       "RandomStandardNormal",
791       "StatelessRandomUniform",
792       "StatelessRandomUniformInt",
793       "StatelessRandomNormal",
794   });
795 
796   // XRT ops refer to per-device handles that are not safe to move between
797   // devices.
798   return unpinnable_ops->find(op_type) == unpinnable_ops->end() &&
799          !absl::StartsWith(op_type, "XRT");
800 }
801 
802 // The Op device may be updated if:
803 // - A resource touching input is specified: all resource-touching ops run in
804 // the device the resource is, regardless of anything else that has been
805 // specified. This is identical to the graph mode behavior.
806 //
807 // - All op inputs are on the CPU, small (<64 elements) and integers
808 // (int32/int64). This can be disabled by setting the environment variable
809 // "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
MaybeUpdateOpDevice(EagerOperation * op)810 Status MaybeUpdateOpDevice(EagerOperation* op) {
811   if (op->is_function()) {
812     // Don't update the device of direct function calls.
813     // Particularly, if the user did not explicitly request any device for this
814     // function, picking a device would result in this device being the default
815     // for nodes inside the function. This is undesirable for multi-device
816     // functions since the not-explicitly-placed nodes inside the body will all
817     // end up on this default device.
818     return Status::OK();
819   }
820   EagerContext* ctx = op->EagerContext();
821   bool all_inputs_eligible_for_cpu_pinning =
822       ctx->PinSmallOpsToCPU() && !op->is_function() && IsPinnableOp(op->Name());
823   Device* op_device = op->Device() == nullptr ? ctx->HostCPU() : op->Device();
824   for (int i = 0; i < op->Inputs().size(); ++i) {
825     TensorHandle* tensor_handle = op->Inputs()[i];
826     if (tensor_handle->dtype == DT_RESOURCE) {
827       Device* resource_device = tensor_handle->resource_device();
828       VLOG(2) << "for op " << op->Name() << " input " << i << " "
829               << DataTypeString(tensor_handle->dtype)
830               << " input device = " << resource_device->name()
831               << ", op device = " << op_device->name();
832       // We check for `op->Device() == nullptr` because it can be later
833       // interpreted as unspecified device and a different device can
834       // be selected based on device priority. If any input to an op
835       // is a resource we must pin it to prevent different device selection.
836       // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
837       if (resource_device != op_device || op->Device() == nullptr) {
838         VLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
839                 << "device of operation " << op->Name() << " to "
840                 << resource_device->name() << " because input #" << i
841                 << " is a resource in this device.";
842         op->SetDevice(resource_device);
843       }
844       all_inputs_eligible_for_cpu_pinning = false;
845       // No point in looking at other inputs. If there are other resources,
846       // they must have the same device and we already declared the op to be
847       // ineligible for CPU pinning.
848       break;
849     } else if (all_inputs_eligible_for_cpu_pinning) {
850       Device* input_device = tensor_handle->device();
851       input_device = input_device == nullptr ? ctx->HostCPU() : input_device;
852       VLOG(2) << "for op " << op->Name() << " input " << i << " "
853               << DataTypeString(tensor_handle->dtype)
854               << " input device = " << input_device->name()
855               << ", op device = " << op_device->name();
856 
857       // Input is on CPU.
858       if (input_device != ctx->HostCPU()) {
859         all_inputs_eligible_for_cpu_pinning = false;
860         continue;
861       }
862 
863       if (tensor_handle->dtype != DataType::DT_INT32 &&
864           tensor_handle->dtype != DataType::DT_INT64) {
865         all_inputs_eligible_for_cpu_pinning = false;
866         continue;
867       }
868 
869       int64 num_elements;
870       TF_RETURN_IF_ERROR(tensor_handle->NumElements(&num_elements));
871       if (num_elements > 64) {
872         all_inputs_eligible_for_cpu_pinning = false;
873       }
874     }
875   }
876 
877   // Ops without inputs are usually ops that generate a tensor in some way and
878   // usually require being present on whatever device they are scheduled on
879   // - for e.g. VarHandleOp or _Recv).
880   // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
881   // an op, but there is a GPU kernel?
882   if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
883     VLOG(1) << "Forcing op " << op->Name()
884             << " to be on the CPU since all input tensors have an "
885                "int32/int64 dtype, and are small (less than 64 elements).";
886     op->SetDevice(ctx->HostCPU());
887   }
888 
889   return Status::OK();
890 }
891 }  // namespace
892 
EagerExecute(EagerOperation * op,gtl::InlinedVector<TensorHandle *,2> * retvals,int * num_retvals)893 Status EagerExecute(EagerOperation* op,
894                     gtl::InlinedVector<TensorHandle*, 2>* retvals,
895                     int* num_retvals) {
896   TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
897 
898   bool op_is_local = IsLocal(op->EagerContext(), op->Device());
899 
900   if (op_is_local) {
901     return EagerLocalExecute(op, retvals, num_retvals);
902   }
903 
904   if (op->EagerContext()->LogDevicePlacement()) {
905     LOG(INFO) << "Executing op " << op->Name() << " in device "
906               << op->Device()->name();
907   }
908 
909   return EagerRemoteExecute(op, retvals->data(), num_retvals);
910 }
911 
EagerKernelExecute(EagerContext * ctx,const gtl::InlinedVector<TensorHandle *,4> & op_inputs,KernelAndDevice * kernel,NodeExecStats * maybe_stats,StepStats * maybe_step_stats,GraphCollector * graph_collector,TensorHandle ** retvals,int num_retvals)912 Status EagerKernelExecute(EagerContext* ctx,
913                           const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
914                           KernelAndDevice* kernel, NodeExecStats* maybe_stats,
915                           StepStats* maybe_step_stats,
916                           GraphCollector* graph_collector,
917                           TensorHandle** retvals, int num_retvals) {
918   std::vector<Tensor> outputs(1);
919 
920   // If there are multiple references to a TensorHandle in 'op_inputs' we must
921   // increment the reference count of the corresponding Tensor or risk it being
922   // overwritten during kernel execution. The reference count is incremented
923   // below when we insert a copy of the Tensor into protected_tensors, and will
924   // be decremented once execution is complete.
925   std::vector<tensorflow::Tensor> protected_tensors;
926   for (int i = 0; i < op_inputs.size(); ++i) {
927     if (!op_inputs[i]->RefCountIsOne()) {
928       const Tensor* input_tensor = nullptr;
929       TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
930       protected_tensors.push_back(*input_tensor);
931     }
932   }
933 
934   gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
935   for (int i = 0; i < op_inputs.size(); ++i) {
936     TF_RETURN_IF_ERROR(op_inputs[i]->TensorValue(&input_vector[i]));
937   }
938 
939   //  TODO(apassos) figure out how to record stats for ops which are a part of
940   //  functions.
941   // TODO(agarwal): change Run to take vector of handles ?
942   ScopedStepContainer* container = ctx->StepContainer();
943   if (container == nullptr) {
944     TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, maybe_stats,
945                                    maybe_step_stats, graph_collector));
946   } else {
947     TF_RETURN_IF_ERROR(kernel->Run(container, input_vector, &outputs,
948                                    maybe_stats, maybe_step_stats,
949                                    graph_collector));
950   }
951   if (graph_collector != nullptr) {
952     mutex_lock ml(*ctx->MetadataMu());
953     {
954       GraphCollector* collector = ctx->GetGraphCollector();
955       mutex_lock mll(collector->mu);
956 
957       // Adding to partition graphs for backward compatibility.
958       for (const auto& graph : collector->partitioned_graphs) {
959         *ctx->RunMetadataProto()->add_partition_graphs() = graph;
960       }
961 
962       if (collector->dirty) {
963         auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
964         *function_graphs->mutable_post_optimization_graph() =
965             collector->optimized_graph;
966         *function_graphs->mutable_pre_optimization_graph() =
967             collector->raw_graph;
968         for (const auto& graph : collector->partitioned_graphs) {
969           *function_graphs->add_partition_graphs() = graph;
970         }
971       }
972 
973       collector->ClearGraphs();
974     }
975   }
976   if (maybe_stats != nullptr) {
977     int64 nanos = Env::Default()->NowNanos();
978     maybe_stats->set_op_end_rel_micros(nanos / EnvTime::kMicrosToNanos -
979                                        maybe_stats->all_start_micros());
980     maybe_stats->set_op_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
981     maybe_stats->set_all_end_rel_micros(nanos / EnvTime::kMicrosToNanos -
982                                         maybe_stats->all_start_micros());
983     maybe_stats->set_all_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
984     if (ctx->ShouldStoreStepStats()) {
985       mutex_lock ml(*ctx->MetadataMu());
986       {
987         auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
988         // Lazily initialize the RunMetadata with information about all devices
989         // if this is the first call.
990         while (step_stats->dev_stats_size() < ctx->devices()->size()) {
991           step_stats->add_dev_stats();
992         }
993         // Find the current device's index.
994         // If device is a nullptr (we are running a function without explicitly
995         // requested device), attribute the function runtime to CPU.
996         Device* attribution_device = kernel->device();
997         if (attribution_device == nullptr) {
998           attribution_device = ctx->HostCPU();
999         }
1000         int device_idx = 0;
1001         for (int i = 0; i < ctx->devices()->size(); ++i) {
1002           if (ctx->devices()->at(i) == attribution_device) {
1003             device_idx = i;
1004             break;
1005           }
1006         }
1007         // Populate the device stats for this device.
1008         auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
1009         dev_stats->set_device(attribution_device->name());
1010         *dev_stats->add_node_stats() = *maybe_stats;
1011       }
1012     }
1013   }
1014   DCHECK_EQ(num_retvals, outputs.size());
1015   for (int i = 0; i < num_retvals; ++i) {
1016     if (retvals[i] == nullptr) {
1017       retvals[i] =
1018           new TensorHandle(outputs[i], /* d= */ kernel->OutputDevice(i),
1019                            /* op_device= */ kernel->device(), ctx);
1020     } else {
1021       // In the async case, the retval is not a nullptr, and its device is
1022       // already set since all TensorHandles always have their device set
1023       // (potentially to nullptr) during construction.
1024       DCHECK_EQ(kernel->device(), retvals[i]->op_device());
1025       DCHECK_EQ(kernel->OutputDevice(i), retvals[i]->device());
1026 
1027       retvals[i]->SetTensor(outputs[i]);
1028     }
1029   }
1030   return Status::OK();
1031 }
1032 
1033 namespace {
1034 
LocalEagerCopyToDevice(TensorHandle * h,EagerContext * ctx,Device * dstd,TensorHandle ** result)1035 Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
1036                               TensorHandle** result) {
1037   TF_RETURN_IF_ERROR(ctx->GetStatus());
1038   if (ctx->Async()) {
1039     // Note that `h` may not be currently ready. However execution order will
1040     // make sure that `h` is ready before the copy is actually done.
1041     CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
1042     TensorHandle* output = node->dst();
1043     // Note that calling Add makes `node` accessible by the EagerExecutor
1044     // thread. So further accesses need to be thread-safe.
1045     ctx->ExecutorAdd(node);
1046     *result = output;
1047     return Status::OK();
1048   } else {
1049     TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, dstd, result));
1050     return Status::OK();
1051   }
1052 }
1053 
ExecuteSend(EagerContext * ctx,tensorflow::Device * device,TensorHandle * h,StringPiece wire_id,const string & recv_device)1054 Status ExecuteSend(EagerContext* ctx, tensorflow::Device* device,
1055                    TensorHandle* h, StringPiece wire_id,
1056                    const string& recv_device) {
1057   const tensorflow::AttrTypeMap* types;
1058   bool is_function = false;
1059   TF_RETURN_IF_ERROR(
1060       tensorflow::AttrTypeMapForOp("_Send", &types, &is_function));
1061   DCHECK(!is_function);
1062   tensorflow::EagerOperation op(ctx, "_Send", /*is_function=*/false, types);
1063 
1064   op.AddInput(h);
1065 
1066   op.SetDevice(device);
1067 
1068   op.MutableAttrs()->Set("tensor_name", wire_id);
1069   op.MutableAttrs()->Set("send_device", device->name());
1070   op.MutableAttrs()->Set(
1071       "send_device_incarnation",
1072       static_cast<int64>(device->attributes().incarnation()));
1073   op.MutableAttrs()->Set("recv_device", recv_device);
1074   op.MutableAttrs()->Set("client_terminated", false);
1075 
1076   op.MutableAttrs()->Set("T", h->dtype);
1077 
1078   int num_outputs = 0;
1079   gtl::InlinedVector<TensorHandle*, 2> retvals;
1080 
1081   return EagerExecute(&op, &retvals, &num_outputs);
1082 }
1083 
ExecuteRecv(EagerContext * ctx,tensorflow::Device * device,DataType dtype,StringPiece wire_id,const string & send_device,int64 send_device_incarnation,TensorHandle ** result)1084 Status ExecuteRecv(EagerContext* ctx, tensorflow::Device* device,
1085                    DataType dtype, StringPiece wire_id,
1086                    const string& send_device, int64 send_device_incarnation,
1087                    TensorHandle** result) {
1088   const tensorflow::AttrTypeMap* types;
1089   bool is_function = false;
1090   TF_RETURN_IF_ERROR(
1091       tensorflow::AttrTypeMapForOp("_Recv", &types, &is_function));
1092   DCHECK(!is_function);
1093   tensorflow::EagerOperation op(ctx, "_Recv", /*is_function=*/false, types);
1094 
1095   op.SetDevice(device);
1096 
1097   op.MutableAttrs()->Set("tensor_name", wire_id);
1098   op.MutableAttrs()->Set("send_device", send_device);
1099   op.MutableAttrs()->Set("send_device_incarnation", send_device_incarnation);
1100   op.MutableAttrs()->Set("recv_device", device->name());
1101   op.MutableAttrs()->Set("client_terminated", false);
1102 
1103   op.MutableAttrs()->Set("tensor_type", dtype);
1104 
1105   int num_outputs = 1;
1106   gtl::InlinedVector<TensorHandle*, 2> retvals(num_outputs);
1107 
1108   TF_RETURN_IF_ERROR(EagerExecute(&op, &retvals, &num_outputs));
1109 
1110   *result = retvals.at(0);
1111 
1112   return Status::OK();
1113 }
1114 
1115 // This gets a unique wire ID. We add a random identifier so that if the
1116 // worker has other clients that it is servicing, we don't have any collision.
GetUniqueWireID()1117 string GetUniqueWireID() {
1118   static tensorflow::uint64 random_seed = random::New64();
1119   static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
1120   static tensorflow::int64 wireid GUARDED_BY(wireid_mutex) = 0;
1121   tensorflow::mutex_lock l(wireid_mutex);
1122   return strings::StrCat(random_seed, "_", wireid++);
1123 }
1124 
1125 }  // namespace
1126 
EagerCopyToDevice(TensorHandle * h,EagerContext * ctx,const char * device_name,TensorHandle ** result)1127 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1128                          const char* device_name, TensorHandle** result) {
1129   tensorflow::Device* send_device = h->device();
1130 
1131   if (send_device == nullptr) {
1132     send_device = ctx->HostCPU();
1133   }
1134 
1135   bool sender_is_local = IsLocal(ctx, send_device);
1136 
1137   tensorflow::Device* recv_device;
1138   TF_RETURN_IF_ERROR(FindDeviceFromName(ctx, device_name, &recv_device));
1139 
1140   bool recver_is_local = IsLocal(ctx, recv_device);
1141 
1142   if (sender_is_local && recver_is_local) {
1143     return LocalEagerCopyToDevice(h, ctx, recv_device, result);
1144   } else if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
1145     return EagerRemoteSendTensor(ctx, h, recv_device, result);
1146   } else {
1147     string wire_id = GetUniqueWireID();
1148 
1149     TF_RETURN_IF_ERROR(
1150         ExecuteSend(ctx, send_device, h, wire_id, recv_device->name()));
1151 
1152     return ExecuteRecv(ctx, recv_device, h->dtype, wire_id, send_device->name(),
1153                        send_device->attributes().incarnation(), result);
1154   }
1155 }
1156 }  // namespace tensorflow
1157