1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/execute.h"
17 
18 #include <cstddef>
19 #include <vector>
20 
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/lib/core/refcount.h"
29 #include "tensorflow/core/platform/platform.h"
30 // clang-format on
31 
32 #include "absl/strings/match.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/types/optional.h"
35 #include "tensorflow/compiler/jit/defs.h"
36 #include "tensorflow/core/common_runtime/device.h"
37 #include "tensorflow/core/common_runtime/device_set.h"
38 #include "tensorflow/core/common_runtime/eager/context.h"
39 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
40 #include "tensorflow/core/common_runtime/eager/execute_node.h"
41 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
42 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
43 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
44 #include "tensorflow/core/common_runtime/colocation_graph.h"
45 #include "tensorflow/core/framework/dataset.h"
46 #include "tensorflow/core/framework/function.h"
47 #include "tensorflow/core/framework/logging.h"
48 #include "tensorflow/core/framework/node_def_util.h"
49 #include "tensorflow/core/framework/tensor_reference.h"
50 #include "tensorflow/core/framework/types.pb.h"
51 #include "tensorflow/core/lib/core/errors.h"
52 #include "tensorflow/core/profiler/lib/traceme.h"
53 #include "tensorflow/core/util/device_name_utils.h"
54 #if !defined(IS_MOBILE_PLATFORM)
55 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
56 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
57 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
58 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
59 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
60 #endif  // IS_MOBILE_PLATFORM
61 #include "tensorflow/core/framework/step_stats.pb.h"
62 #include "tensorflow/core/framework/tensor.h"
63 #include "tensorflow/core/framework/types.h"
64 #include "tensorflow/core/lib/core/status.h"
65 #include "tensorflow/core/lib/gtl/cleanup.h"
66 #include "tensorflow/core/lib/gtl/flatset.h"
67 #include "tensorflow/core/lib/gtl/inlined_vector.h"
68 #include "tensorflow/core/lib/random/random.h"
69 #include "tensorflow/core/platform/env.h"
70 #include "tensorflow/core/platform/mutex.h"
71 #include "tensorflow/core/util/ptr_util.h"
72 #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
73 
74 namespace tensorflow {
75 
76 namespace {
77 
78 const string& DeviceNameOrUnspecified(Device* device) {
79   static string* unspecified_string = new string("<unspecified>");
80   return (device == nullptr) ? *unspecified_string : device->name();
81 }
82 
83 // This function expects *handle to point to an existing tensor handle that is
84 // currently on "handle_device", but where the operation expects that input to
85 // reside on "expected_input_device".  The function will arrange for this
86 // transfer to happen and will return OK on success and will storage a new
87 // handle to the equivalent tensor on the correct device in "*result".  Or if an
88 // error is encountered, it will return a non-OK status and set "*result" to
89 // nullptr.
90 //
91 // `op_device` is passed in explicitly because `op->device()` might be
92 // unset and we might have selected some specific device to run this op on.
93 Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op,
94                                  Device* op_device,
95                                  TensorHandle* handle,  // op->Inputs()[i]
96                                  int i, Device* handle_device,
97                                  Device* expected_input_device,
98                                  TensorHandle** result) {
99   // Should only be called when these don't match
100   DCHECK(expected_input_device != handle_device);
101   *result = nullptr;
102   const string& op_device_name = DeviceNameOrUnspecified(op_device);
103 
104   switch (ctx->GetDevicePlacementPolicy()) {
105     case DEVICE_PLACEMENT_SILENT_FOR_INT32:
106       // TODO(xpan): See if we could bubble python related error up
107       // to python level.
108       if (handle->dtype == DT_INT32) {
109         // Note: enabling silent copies of int32 tensors to match behavior
110         // of graph mode.
111         break;
112       }
113       TF_FALLTHROUGH_INTENDED;
114     case DEVICE_PLACEMENT_EXPLICIT:
115       return errors::InvalidArgument(
116           "Tensors on conflicting devices:"
117           " cannot compute ",
118           op->Name(), " as input #", i, " was expected to be on ",
119           expected_input_device->name(), " but is actually on ",
120           handle_device->name(), " (operation running on ", op_device_name, ")",
121           " Tensors can be copied explicitly using:"
122           " `with tf.device(device_name): x = tf.identity(x)`"
123           " or transparently copied by using"
124           " tf.config.experimental.set_device_policy('silent')."
125           " Copying tensors between devices may slow down your model");
126     case DEVICE_PLACEMENT_WARN:
127       LOG(WARNING) << "before computing " << op->Name() << " input #" << i
128                    << " was expected to be on " << expected_input_device->name()
129                    << " but is actually on " << handle_device->name()
130                    << " (operation running on " << op_device_name
131                    << "). This triggers a copy which can be a performance "
132                       "bottleneck.";
133       break;
134     case DEVICE_PLACEMENT_SILENT:  // Do nothing.
135       break;
136   }
137   // We are only here if the policy is warn or silent copies, so we should
138   // trigger a copy.
139   TensorHandle* result_handle = nullptr;
140   profiler::TraceMe activity(
141       [&] {
142         return absl::StrCat("_Send input ", i, " from ", handle_device->name(),
143                             " to ", expected_input_device->name());
144       },
145       profiler::TraceMeLevel::kInfo);
146   Status status =
147       EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device,
148                         ctx->MirrorTensors(), &result_handle);
149   activity.Stop();
150   if (!status.ok()) {
151     return errors::Internal("Failed copying input tensor from ",
152                             handle_device->name(), " to ",
153                             expected_input_device->name(), " in order to run ",
154                             op->Name(), ": ", status.error_message());
155   }
156 
157   *result = result_handle;
158 
159   return Status::OK();
160 }
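// Usage sketch (illustrative, not taken from this file): a caller that finds a
// handle on the wrong device might invoke the helper roughly as follows; the
// variable names are hypothetical.
//
//   TensorHandle* moved = nullptr;
//   TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
//       ctx, op, /*op_device=*/gpu_device, handle, /*i=*/0,
//       /*handle_device=*/cpu_device, /*expected_input_device=*/gpu_device,
//       &moved));
//   op->UpdateInput(0, moved);
//   moved->Unref();  // UpdateInput holds its own reference now.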
161 
162 // `op_device_name` is the name of the device on which the op will run, if
163 // any. For functions run using the function library runtime, the device can
164 // be unspecified.
165 Status ValidateInputTypeAndPlacement(
166     EagerContext* ctx, EagerOperation* op,
167     const core::RefCountPtr<KernelAndDevice>& kernel) {
168   profiler::TraceMe activity("ValidateInputTypeAndPlacement",
169                              profiler::TraceMeLevel::kInfo);
170   const int n_inputs = op->Inputs().size();
171   if (kernel->num_inputs() != n_inputs) {
172     return errors::InvalidArgument("expected ", kernel->num_inputs(),
173                                    " inputs, got ", n_inputs);
174   }
175   const bool skip_remote_copy =
176       ctx->LazyCopyFunctionRemoteInputs() && kernel->IsFunction();
177   for (int i = 0; i < n_inputs; ++i) {
178     TensorHandle* handle = op->Inputs()[i];
179     Device* expected_device = kernel->InputDevice(i);
180     Device* handle_device = handle->DeviceOrHostCPU(*ctx);
181     const bool maybe_copy = !skip_remote_copy || !handle->IsRemote();
182     // If the input is already on the right device, then nothing to do.
183     if (expected_device != handle_device && maybe_copy) {
184       TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
185                                                    handle, i, handle_device,
186                                                    expected_device, &handle));
187       op->UpdateInput(i, handle);
188       // Unref handle since it has a ref as an input now
189       handle->Unref();
190     }
191     if (handle->dtype != kernel->input_type(i)) {
192       return errors::InvalidArgument(
193           "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
194           " was expected to be a ", DataTypeString(kernel->input_type(i)),
195           " tensor but is a ", DataTypeString(handle->dtype), " tensor");
196     }
197   }
198   return Status::OK();
199 }
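// Example (illustrative): for a kernel whose input 0 is DT_FLOAT on GPU:0, a
// DT_FLOAT handle living on the host CPU is copied to GPU:0 under the SILENT
// policy (or rejected under EXPLICIT), while a handle of any other dtype fails
// with the InvalidArgument error above regardless of its device.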
200 
201 Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
202   const auto& node_def = op->MutableAttrs()->BuildNodeDef();
203   const OpDef* op_def = nullptr;
204 
205   const FunctionDef* function_def =
206       op->EagerContext().FuncLibDef()->Find(op->Name());
207   if (function_def != nullptr) {
208     op_def = &(function_def->signature());
209   } else {
210     TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
211   }
212 
213   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
214 
215   return Status::OK();
216 }
217 
218 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
219                                                const tensorflow::Fprint128& b) {
220   return {tensorflow::FingerprintCat64(a.low64, b.low64),
221           tensorflow::FingerprintCat64(a.high64, b.high64)};
222 }
223 
224 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
225                                                const int64 b) {
226   auto x = tensorflow::FingerprintCat64(a.low64, b);
227   return {x, tensorflow::FingerprintCat64(a.high64, x)};
228 }
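// Example (illustrative): these helpers are chained below to build kernel
// cache keys, roughly as
//
//   Fprint128 key = op->MutableAttrs()->CacheKey(op->GetDeviceName());
//   key = FingerprintCat128(key, Fingerprint128(input_device->name()));
//   key = FingerprintCat128(key, static_cast<int64>(dtype));
//
// which mirrors what EagerLocalExecute does for function inputs.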
229 
230 Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
231                          Device** result) {
232   Device* cpu_device = ctx.HostCPU();
233   string device_name;
234   if (tensor_handle->IsRemote()) {
235     Device* device = tensor_handle->device();
236     device_name = device != nullptr ? device->name() : cpu_device->name();
237     *result = (device == nullptr ? cpu_device : device);
238   } else if (tensor_handle->dtype == DT_RESOURCE) {
239     // Use the resource's actual device because it is the device that will
240     // influence partitioning the multi-device function.
241     const Tensor* tensor;
242     // TODO(fishx): Avoid blocking here.
243     TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
244     const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
245     device_name = handle.device();
246 
247     Device* input_device;
248     TF_RETURN_IF_ERROR(
249         ctx.FindDeviceFromName(device_name.c_str(), &input_device));
250     *result = input_device;
251   } else if (MTypeFromDType(tensor_handle->dtype) == HOST_MEMORY) {
252     *result = cpu_device;
253   } else {
254     Device* device = tensor_handle->device();
255     device_name = device != nullptr ? device->name() : cpu_device->name();
256     *result = (device == nullptr ? cpu_device : device);
257   }
258   return Status::OK();
259 }
260 
261 // Appends a TensorShape object to Fprint128 hash.
262 // For best performance, we would like to avoid dynamic memory allocation in
263 // this function.
264 // If "shape" has unknown rank, we attach "?" to hashed content; otherwise we
265 // attach every dim size to hashed content.
266 void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
267                                     Fprint128* fingerprint) {
268   if (shape.unknown_rank()) {
269     char c = '?';
270     *fingerprint = FingerprintCat128(*fingerprint, c);
271   } else {
272     for (int i = 0; i < shape.dims(); i++) {
273       int64 dim = shape.dim_size(i);
274       *fingerprint = FingerprintCat128(*fingerprint, dim);
275     }
276   }
277 }
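// Example (illustrative): a shape of unknown rank contributes only the single
// character '?' to the fingerprint, while a partially known shape such as
// [2, ?, 3] contributes each dim size in turn (PartialTensorShape reports the
// unknown dimension as -1).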
278 
279 Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
280                           bool* compile_with_xla) {
281   if (!op->is_function()) {
282     *compile_with_xla = false;
283     return Status::OK();
284   }
285 
286   if (op->remote_func_params().has_value() &&
287       op->remote_func_params().value().step_id.has_value()) {
288     // If the op is a component of a multi-device function, don't compile it
289     // with XLA.
290     *compile_with_xla = false;
291     return Status::OK();
292   }
293 
294   // Does node have an explicit request to compile or not?
295   Status status = op->Attrs().Get(kXlaMustCompileAttr, compile_with_xla);
296   if (status.ok()) {
297     DVLOG(2) << "Caller explicitly requested "
298              << (*compile_with_xla ? "" : "not ")
299              << "to compile with XLA: " << op->DebugString();
300     return Status::OK();
301   }
302 
303   // Does FunctionDef have an explicit request to compile or not?
304   const FunctionDef* function_def =
305       ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name());
306   if (function_def == nullptr) {
307     return errors::NotFound("Failed to find function '", op->Name(), "'");
308   }
309 
310   status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaMustCompileAttr,
311                        compile_with_xla);
312   if (status.ok()) {
313     DVLOG(2) << "Function definition explicitly specifies "
314              << (*compile_with_xla ? "" : "not ") << "to compile with XLA";
315     return Status::OK();
316   }
317 
318   // No explicit requests. Compile for XLA devices by default.
319   if (op->GetDeviceParsedName().type == "TPU" ||
320       op->GetDeviceParsedName().type == "XLA_GPU" ||
321       op->GetDeviceParsedName().type == "XLA_CPU") {
322     DVLOG(2) << "Compiling " << op->Name()
323              << " with XLA because it is running on an XLA device "
324              << op->GetDeviceParsedName().type;
325     *compile_with_xla = true;
326   } else {
327     *compile_with_xla = false;
328   }
329 
330   return Status::OK();
331 }
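// Note (illustrative): a caller can force the "explicit request" branch above
// by setting the attribute before execution, e.g.
//
//   op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
//
// which is what EagerLocalExecute does once it decides to compile with XLA.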
332 
333 // There are a lot of references to devices in this function and around.
334 // Here is what they mean:
335 //  EagerOperation::Device(): The device on which the user requested the op
336 //    be executed, except if we had to change the device due to resource inputs
337 //    or CPU pinning. If the user did not request a device, the op does not
338 //    take resources, and we did not pin it to CPU, the device can be nullptr.
339 //  KernelAndDevice::Device(): The first time we see an op (combined with
340 //    its attributes), we need to create a KernelAndDevice object for it.
341 //    If op->Device() is a nullptr, we select a device for the op when
342 //    creating the KernelAndDevice. A concrete device will always be selected
343 //    here except when `op` is a function to be executed using function library
344 //    runtime. In this case, we don't select a device because running
345 //    a function with an explicitly requested device has different behavior
346 //    than running without an explicitly requested device.
347 Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
348                          int* num_retvals) {
349   MEMDEBUG_CACHE_OP(op->op_name());
350   profiler::TraceMe activity(
351       [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
352       profiler::TraceMeLevel::kInfo);
353   EagerContext& ctx = op->EagerContext();
354   auto& executor = op->Executor();
355   TF_RETURN_IF_ERROR(executor.status());
356   Device* device = op->Device();
357 
358   Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->GetDeviceName());
359 
360   std::vector<Device*> input_dev_ptrs;
361   std::unordered_map<int, DtypeAndPartialTensorShape>
362       input_resource_variable_dtypes_and_shapes;
363   // We can eliminate some overhead by running simple functions using regular
364   // CallOp kernel. However, it is tricky to figure out which functions should
365   // be run using CallOp. Also, currently CallOp runs neither optimization
366   // passes (needed for TPU/XLA) nor grappler.
367   // Here are some cases where a function should be run in multi-device mode:
368   //  - Function takes at least two resources on different devices.
369   //  - Function takes a resource on deviceA and a body op explicitly placed
370   //  on deviceB.
371   //  - Function has a colocation constraint.
372   //  - Function has an explicit device annotation (which might not be using
373   //    full canonical device name) different from op_device. Note that false
374   //    positives are ok.
375   //  - Function has a node or a (node) attribute that can potentially make
376   //    the function multi-device after a rewrite pass (e.g. various XLA/TPU
377   //    special nodes and attributes)
378   if (op->is_function()) {
379     profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
380                                profiler::TraceMeLevel::kInfo);
381     input_dev_ptrs.reserve(op->Inputs().size());
382     // When LazyCopyFunctionRemoteInputs is disabled, all inputs need to be on
383     // local devices, since we execute a remote function through the worker
384     // service, which doesn't accept remote inputs.
385     for (int i = 0; i < op->Inputs().size(); i++) {
386       TensorHandle* input = op->Inputs()[i];
387       if (!ctx.LazyCopyFunctionRemoteInputs() && input->IsRemote()) {
388         TensorHandle* handle = nullptr;
389         TF_RETURN_IF_ERROR(EagerCopyToDevice(
390             input, &ctx, &executor, device == nullptr ? ctx.HostCPU() : device,
391             ctx.MirrorTensors(), &handle));
392         op->UpdateInput(i, handle);
393         // Unref handle since it has a ref as an input now
394         handle->Unref();
395         input = handle;
396       }
397 
398       // Get device for this input, and add it to 'cache_key'.
399       Device* input_device;
400       TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
401       input_dev_ptrs.push_back(input_device);
402       cache_key =
403           FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
404 
405       // If input is a ResourceHandle, get its resource handle dtypes and shapes
406       // and add them to 'cache_key'.
407       if (input->dtype == DT_RESOURCE) {
408         // We only care about data type and shape for resource variable inputs.
409         // But we have no way to tell if the input is a resource variable (other
410         // than looking it up in ResourceMgr, which is slow). So we just get
411         // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
412         // resource_dtypes_and_shapes is not empty, take the first element.
413         std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
414         TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
415             &resource_dtypes_and_shapes));
416         if (!resource_dtypes_and_shapes.empty()) {
417           const DtypeAndPartialTensorShape& dtype_and_shape =
418               resource_dtypes_and_shapes.at(0);
419           input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
420 
421           // Add _Arg index, dtype and shape to "cache_key".
422           cache_key = FingerprintCat128(cache_key, i);
423           DataType dtype = dtype_and_shape.dtype;
424           cache_key = FingerprintCat128(cache_key, dtype);
425           AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key);
426         }
427       }
428     }
429   }
430 
431   core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
432   if (kernel == nullptr) {
433     DVLOG(2) << "Creating new kernel for " << op->Name() << " on device "
434              << DeviceNameOrUnspecified(op->Device());
435     bool run_function_with_flr = false;
436     if (op->is_function()) {
437       bool compile_with_xla;
438       TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
439       if (compile_with_xla) {
440         // Note that it is not ideal, but currently correct, to set this
441         // attribute after computing the kernel cache key above.
442         // Note: If the attribute is already set to true, this is a noop.
443         op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
444       } else {
445         run_function_with_flr = true;
446       }
447     }
448 
449     const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
450     if (device == nullptr) {
451       PrioritizedDeviceTypeVector supported_devs;
452       TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
453           ctx.prioritized_device_type_list(), ndef, &supported_devs,
454           &ctx.HostCPU()->parsed_name()));
455       if (supported_devs.empty()) {
456         return errors::NotFound("Could not find valid device for node.\nNode:",
457                                 FormatNodeDefForError(ndef),
458                                 "\nAll kernels registered for op ", ndef.op(),
459                                 " :\n", KernelsRegisteredForOp(ndef.op()));
460       }
461       TF_RETURN_IF_ERROR(ctx.SelectDevice(op->GetDeviceParsedName(),
462                                           supported_devs, DT_INVALID, &device));
463 
464       DVLOG(1) << "Placer place op [" << op->Name()
465                << "] on device: " << device->name();
466       DVLOG(4) << "Available kernels for " << op->Name() << " are "
467                << KernelsRegisteredForOp(op->Name());
468       op->SetDevice(device);
469     }
470     if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
471       string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
472                                    DeviceNameOrUnspecified(device));
473       if (!logging::LogToListeners(msg)) {
474         LOG(INFO) << msg;
475       }
476     }
477 
478     FunctionLibraryRuntime* flr =
479         device == nullptr ? nullptr : ctx.func_lib(device);
480     if (device != nullptr && flr == nullptr) {
481       return errors::Unavailable(
482           "Unable to find a FunctionLibraryRuntime corresponding to device ",
483           device->name());
484     }
485     auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
486                                                                : ctx.runner();
487     GraphCollector* graph_collector = nullptr;
488     if (ctx.ShouldStoreGraphs()) {
489       graph_collector = ctx.GetGraphCollector();
490     }
491     // Treat the function as multi_device only when we are not compiling
492     // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
493     // will create an XlaLaunchOp kernel to compile and run the function.
494     if (run_function_with_flr) {
495       // Multi-device functions don't use the rendezvous from eager context.
496       // If we use that rendezvous, multiple concurrent calls to the same
497       // function will likely result in collisions. However, this also means
498       // that we don't support legitimate sending/receiving across function
499       // boundary.
500       DVLOG(2) << "Running " << ndef.op() << " using multi-device function. "
501                << "Full node_def=" << ndef.DebugString();
502       std::function<int64()> get_op_id = nullptr;
503 #if !defined(IS_MOBILE_PLATFORM)
504       if (ctx.LazyCopyFunctionRemoteInputs()) {
505         get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
506       }
507 #endif  // IS_MOBILE_PLATFORM
508       kernel.reset(new KernelAndDeviceFunc(
509           flr, ctx.pflr(), std::move(input_dev_ptrs),
510           std::move(input_resource_variable_dtypes_and_shapes), runner,
511           ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
512           [&ctx](const int64 step_id) { return ctx.CreateRendezvous(step_id); },
513           get_op_id));
514     } else {
515       DVLOG(2) << "Running " << ndef.op() << " using op kernel. "
516                << ". Full node_def=" << ndef.DebugString();
517       kernel.reset(new KernelAndDeviceOp(
518           ctx.GetRendezvous(), ctx.LogMemory(), flr, runner,
519           ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
520     }
521 
522     TF_RETURN_IF_ERROR(kernel->Init(ndef, graph_collector));
523 
524     if (op->is_function()) {
525       ctx.AddKernelToCache(cache_key, kernel.get());
526     } else {
527       // Exclude tf.data op kernels from being cached. The reason for this is
528       // that tf.data op kernels that accept a user-defined function will have a
529       // unique cache key every time they are executed (because the user-defined
530       // function is traced every time). Caching such kernels provides no
531       // benefit and in some cases results in linear memory growth of user
532       // programs that build input pipeline graphs in a loop.
533       const OpDef* op_def;
534       TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def));
535       if (!data::DatasetOpKernel::IsDatasetOp(op_def)) {
536         ctx.AddKernelToCache(cache_key, kernel.get());
537       }
538     }
539   }
540   const DataTypeVector& output_dtypes = kernel->output_dtypes();
541   const size_t num_outputs = static_cast<int>(output_dtypes.size());
542   if (num_outputs > *num_retvals) {
543     return errors::InvalidArgument("Expecting ", num_outputs,
544                                    " outputs, but *num_retvals is ",
545                                    *num_retvals);
546   }
547   *num_retvals = num_outputs;
548   TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
549 
550   GraphCollector* graph_collector = nullptr;
551   if (ctx.ShouldStoreGraphs()) {
552     graph_collector = ctx.GetGraphCollector();
553   }
554 
555   const bool async = executor.Async();
556   for (int i = 0; i < num_outputs; ++i) {
557     TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
558         async,
559         /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
560         /* op_device= */ kernel->device(),
561         /* resource_device= */ kernel->OutputResourceDevice(i),
562         output_dtypes[i], &ctx, &retvals[i]));
563   }
564 
565   Status s;
566   if (async) {
567     auto node = absl::make_unique<ExecuteNode>(
568         &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
569         graph_collector, output_dtypes, op->GetCancellationManager(),
570         executor.Async(), absl::Span<TensorHandle*>(retvals, num_outputs));
571     // For async mode, execution order will make sure that all
572     // input handles are ready before executing them.
573     // TODO(b/137118203): Consider executing "cheap" kernels inline for
574     // performance.
575     s = executor.AddOrExecute(std::move(node));
576   } else {
577     ExecuteNode node(&ctx, op->Inputs(), op->remote_func_params(),
578                      std::move(kernel), graph_collector, output_dtypes,
579                      op->GetCancellationManager(), executor.Async(),
580                      {retvals, num_outputs});
581     s = executor.SyncExecute(&node);
582   }
583   // If the operation failed, we need to Unref any outputs that were
584   // allocated.
585   if (!s.ok()) {
586     for (int i = 0; i < num_outputs; ++i) {
587       retvals[i]->Unref();
588     }
589   }
590 
591   return s;
592 }
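// Note on the kernel cache key built above (summary, not new behavior): it
// always covers the op's attributes and the requested device name; for
// functions it additionally folds in every input's device and, for DT_RESOURCE
// inputs, the resource variable's dtype and shape, so the same function called
// with differently placed or shaped resource inputs instantiates a distinct
// kernel.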
593 
594 #if !defined(IS_MOBILE_PLATFORM)
595 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
596   EagerContext& ctx = op->EagerContext();
597 
598   remote_op->set_id(ctx.RemoteMgr()->NextOpId());
599   remote_op->set_name(op->Name());
600 
601   op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
602   remote_op->set_device(op->Device()->name());
603   remote_op->set_is_function(op->is_function());
604 }
605 
606 Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
607                                     const DataTypeVector& output_dtypes,
608                                     TensorHandle** retvals) {
609   if (remote_op.name() == "VarHandleOp") {
610     if (output_dtypes.size() != 1) {
611       return errors::Internal("VarHandleOp should only have one output.");
612     }
613     if (output_dtypes[0] != DT_RESOURCE) {
614       return errors::Internal(
615           "The output of VarHandleOp should be a DT_RESOURCE.");
616     }
617     AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
618     const AttrValue* dtype;
619     TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
620     const AttrValue* shape;
621     TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
622     retvals[0]->SetResourceHandleDtypeAndShape(
623         {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
624   }
625   return Status::OK();
626 }
627 
628 Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
629                           int* num_retvals) {
630   EagerContext& ctx = op->EagerContext();
631 
632   // TODO(fishx): Remove following code when lazy tensor copy is ready.
633   if (op->Device() == nullptr) {
634     tensorflow::Device* device = nullptr;
635     string device_name = op->GetDeviceName();
636     TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
637     op->SetDevice(device);
638   }
639 
640   core::RefCountPtr<eager::EagerClient> eager_client;
641   uint64 context_id = ctx.GetContextId();
642   TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client));
643   string remote_task;
644   if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
645     return errors::InvalidArgument(
646         "Unable to find remote task corresponding to device ",
647         op->Device()->name());
648   }
649 
650   std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
651   request->set_context_id(context_id);
652 
653   eager::Operation* remote_op = request->add_queue()->mutable_operation();
654 
655   {
656     profiler::TraceMe activity("CopyInputToExpectedDevice",
657                                profiler::TraceMeLevel::kInfo);
658     const bool eagerly_copy_function_remote_inputs =
659         !ctx.LazyCopyFunctionRemoteInputs() || !op->is_function();
660     for (int i = 0; i < op->Inputs().size(); i++) {
661       tensorflow::TensorHandle* input = op->Inputs()[i];
662       tensorflow::Device* input_device = input->device();
663       const string* input_device_name = &input->DeviceOrHostCPU(ctx)->name();
664       bool serialize_resource_dtype_and_shape = false;
665       if (op->Device() != input_device &&
666           // If the expected and actual devices are on the same task, don't
667           // explicitly copy, and instead depend on the copy to happen locally
668           // when the op is executed on the device.
669           !ctx.OnSameTask(op->Device(), input_device)) {
670         if (eagerly_copy_function_remote_inputs ||
671             input->DeviceOrHostCPU(ctx)->IsLocal()) {
672           tensorflow::Device* remote_cpu_device;
673           TF_RETURN_IF_ERROR(
674               ctx.CPUDeviceOnTask(op->Device(), &remote_cpu_device));
675           // TODO(b/110044833): It's possible the same tensor gets copied to the
676           // remote device repeatedly.
677           // Always copy to the remote CPU so that the actual device can be
678           // correctly determined after the kernel is selected/instantiated,
679           // since the op might have its inputs on host memory.
680           TensorHandle* handle = op->Inputs()[i];
681           Device* handle_device = handle->DeviceOrHostCPU(ctx);
682           // If the input is already on the right device, then nothing to do.
683           if (remote_cpu_device != handle_device) {
684             TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
685                 &ctx, op, op->Device(), handle, i, handle_device,
686                 remote_cpu_device, &handle));
687             op->UpdateInput(i, handle);
688             input = handle;
689             input_device = remote_cpu_device;
690             input_device_name = &remote_cpu_device->name();
691             // Unref handle since it has a ref as an input now
692             handle->Unref();
693           }
694         } else {
695           serialize_resource_dtype_and_shape =
696               (input->dtype == DT_RESOURCE) &&
697               (!input->HasResourceShapeMirror(op->Device()));
698         }
699       }
700       auto* input_handle = remote_op->add_inputs();
701       TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
702           input, input_handle, input_device, *input_device_name,
703           serialize_resource_dtype_and_shape));
704       if (!input_handle->resource_dtypes_and_shapes().empty()) {
705         auto tensor_handle_data =
706             absl::make_unique<UnshapedRemoteTensorHandleData>(
707                 input_handle->op_id(), input_handle->output_num(), remote_task,
708                 context_id, &ctx);
709         TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
710             std::move(tensor_handle_data), op->Device()));
711       }
712     }
713   }
714 
715   PrepareRemoteOp(remote_op, op);
716 
717   DataTypeVector output_dtypes;
718   TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
719 
720   const size_t num_outputs = static_cast<int>(output_dtypes.size());
721   if (num_outputs != *num_retvals) {
722     return errors::InvalidArgument(
723         "num_retvals does not match expected output dtypes");
724   }
725   *num_retvals = num_outputs;
726 
727   tensorflow::Device* op_device = op->Device();
728   const tensorflow::uint64 id = remote_op->id();
729   for (int i = 0; i < num_outputs; ++i) {
730     // TODO(nareshmodi): Change the callback to instead add the decref to a
731     // list of pending decrefs that we can send as a batch with the next
732     // execute.
733 
734     // The device_ and resource_device_ of this TensorHandle might be
735     // incorrect. It is pretty hard to make it correct because for
736     // multi-device functions, we don't know the output device until the
737     // function is instantiated. Luckily, we don't need to know the correct
738     // remote device here. We just need to know that it is remote. If we need
739     // to copy this tensor to this process, the remote end will know the
740     // correct device of this handle.
741     Status status = TensorHandle::CreateUnshapedRemoteHandle(
742         id, i, remote_task, context_id, output_dtypes[i], op_device, &ctx,
743         &retvals[i]);
744     if (!status.ok()) {
745       for (int j = 0; j < i; ++j) {
746         retvals[j]->Poison(errors::Internal(
747             "Failed to construct unshaped remote tensor handle at index ", i,
748             " for op ", op->Name()));
749       }
750       return status;
751     }
752   }
753 
754   if (ctx.LazyCopyFunctionRemoteInputs()) {
755     // Store the data type and shape of a remote resource variable on the
756     // corresponding remote TensorHandle (output of 'VarHandleOp').
757     // If the variable is an input of a remote function, the function may need
758     // the type and shape during function instantiation. When
759     // LazyCopyFunctionRemoteInputs is enabled, we no longer copy the resource
760     // handle (contains the type and shape) of the variable to the default
761     // function device. Instead, we store the type and shape on eager master
762     // and send them to the default function device along with the
763     // EnqueueRequest.
764     TF_RETURN_IF_ERROR(
765         StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
766   }
767 
768   auto& executor = op->Executor();
769   DVLOG(4) << "Execute remote eager op: " << op->Name()
770            << " (is async?: " << executor.Async() << ").";
771 
772   std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
773       std::move(request), op_device, eager_client.get(),
774       op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
775       op->Inputs(), {retvals, num_outputs}));
776   Status s = executor.AddOrExecute(std::move(node));
777   // If the operation failed, we need to Unref any outputs that were
778   // allocated.
779   if (!s.ok()) {
780     for (int i = 0; i < num_outputs; ++i) {
781       retvals[i]->Unref();
782     }
783   }
784 
785   return s;
786 }
787 #endif  // IS_MOBILE_PLATFORM
788 
789 // These ops are not pinnable since they generate data. It can be slower to
790 // generate and then copy the data instead of just generating the data on the
791 // device directly.
792 bool IsPinnableOp(const string& op_type) {
793   static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
794       "RandomUniform",
795       "RandomUniformInt",
796       "RandomStandardNormal",
797       "StatelessRandomUniform",
798       "StatelessRandomUniformInt",
799       "StatelessRandomNormal",
800   });
801 
802   // XRT ops refer to per-device handles that are not safe to move between
803   // devices.
804   return unpinnable_ops->find(op_type) == unpinnable_ops->end() &&
805          !absl::StartsWith(op_type, "XRT");
806 }
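// Examples (illustrative): IsPinnableOp("Add") returns true, while
// IsPinnableOp("RandomUniform") and IsPinnableOp("XRTCompile") return false.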
807 
808 // The Op device may be updated if:
809 // - A resource touching input is specified: all resource-touching ops run in
810 // the device the resource is, regardless of anything else that has been
811 // specified. This is identical to the graph mode behavior.
812 //
813 // - All op inputs are on the CPU, small (<64 elements) and integers
814 // (int32/int64). This can be disabled by setting the environment variable
815 // "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
816 Status MaybeUpdateOpDevice(EagerOperation* op) {
817   const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
818   if (op->is_function() || exempt_ops.find(op->Name()) != exempt_ops.end()) {
819     // Don't update the device of direct function calls.
820     // Particularly, if the user did not explicitly request any device for this
821     // function, picking a device would result in this device being the default
822     // for nodes inside the function. This is undesirable for multi-device
823     // functions since the not-explicitly-placed nodes inside the body will all
824     // end up on this default device.
825     return Status::OK();
826   }
827   EagerContext& ctx = op->EagerContext();
828   bool all_inputs_eligible_for_cpu_pinning =
829       ctx.PinSmallOpsToCPU() && !op->is_function() && IsPinnableOp(op->Name());
830   Device* op_device = op->Device() == nullptr ? ctx.HostCPU() : op->Device();
831   for (int i = 0; i < op->Inputs().size(); ++i) {
832     TensorHandle* tensor_handle = op->Inputs()[i];
833     if (tensor_handle->dtype == DT_RESOURCE) {
834       Device* resource_device = tensor_handle->resource_device();
835       DVLOG(2) << "for op " << op->Name() << " input " << i << " "
836                << DataTypeString(tensor_handle->dtype)
837                << " input device = " << resource_device->name()
838                << ", op device = " << op_device->name();
839       // We check for `op->Device() == nullptr` because it can be later
840       // interpreted as unspecified device and a different device can
841       // be selected based on device priority. If any input to an op
842       // is a resource we must pin it to prevent different device selection.
843       // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
844       if (resource_device != op_device || op->Device() == nullptr) {
845         DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
846                  << "device of operation " << op->Name() << " to "
847                  << resource_device->name() << " because input #" << i
848                  << " is a resource in this device.";
849         op->SetDevice(resource_device);
850       }
851       all_inputs_eligible_for_cpu_pinning = false;
852       // No point in looking at other inputs. If there are other resources,
853       // they must have the same device and we already declared the op to be
854       // ineligible for CPU pinning.
855       break;
856     } else if (all_inputs_eligible_for_cpu_pinning) {
857       Device* input_device = tensor_handle->DeviceOrHostCPU(ctx);
858       DVLOG(2) << "for op " << op->Name() << " input " << i << " "
859                << DataTypeString(tensor_handle->dtype)
860                << " input device = " << input_device->name()
861                << ", op device = " << op_device->name();
862 
863       // Input is on CPU.
864       if (input_device != ctx.HostCPU()) {
865         all_inputs_eligible_for_cpu_pinning = false;
866         continue;
867       }
868 
869       if (tensor_handle->dtype != DataType::DT_INT32 &&
870           tensor_handle->dtype != DataType::DT_INT64) {
871         all_inputs_eligible_for_cpu_pinning = false;
872         continue;
873       }
874 
875       int64 num_elements;
876       TF_RETURN_IF_ERROR(tensor_handle->NumElements(&num_elements));
877       if (num_elements > 64) {
878         all_inputs_eligible_for_cpu_pinning = false;
879       }
880     }
881   }
882 
883   // Ops without inputs are usually ops that generate a tensor in some way and
884   // usually require being present on whatever device they are scheduled on
885 // (e.g. VarHandleOp or _Recv).
886   // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
887   // an op, but there is a GPU kernel?
888   if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
889     DVLOG(1) << "Forcing op " << op->Name()
890              << " to be on the CPU since all input tensors have an "
891                 "int32/int64 dtype, and are small (less than 64 elements).";
892     op->SetDevice(ctx.HostCPU());
893   }
894 
895   return Status::OK();
896 }
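// Example (illustrative): with small-op pinning enabled, an "Add" whose int32
// inputs live on the host and have at most 64 elements each is forced onto the
// CPU; a float input or a larger input leaves the requested device unchanged,
// and a resource input instead moves the op to the resource's device.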
897 }  // namespace
898 
899 Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
900                     int* num_retvals) {
901   profiler::TraceMe activity(
902       [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
903       profiler::TraceMeLevel::kInfo);
904   TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
905 
906   if (!op->Executor().Async()) {
907     // In sync mode, always clear error to maintain the same behavior as before.
908     // TODO(b/141004939): Remove this.
909     op->Executor().ClearError();
910   }
911 
912   std::unique_ptr<tensorflow::EagerOperation> out_op;
913   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
914       EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
915 
916   if (op->IsLocal()) {
917     if (out_op) {
918       op = out_op.get();
919     }
920     return EagerLocalExecute(op, retvals, num_retvals);
921   }
922 
923   if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
924     string msg = strings::StrCat(
925         "Executing op ", op->Name(), " on task ",
926         DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
927     if (!logging::LogToListeners(msg)) {
928       LOG(INFO) << msg;
929     }
930   }
931 
932 #if defined(IS_MOBILE_PLATFORM)
933   return errors::Unimplemented(
934       "Eager's remote execution is not available on mobile devices.");
935 #else   // !IS_MOBILE_PLATFORM
936   if (out_op) {
937     op = out_op.get();
938   }
939   return EagerRemoteExecute(op, retvals, num_retvals);
940 #endif  // !IS_MOBILE_PLATFORM
941 }
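// Caller sketch (hypothetical, for illustration only; kMaxOutputs is an
// assumed caller-side bound, not a symbol defined here):
//
//   tensorflow::TensorHandle* retvals[kMaxOutputs];
//   int num_retvals = kMaxOutputs;
//   TF_RETURN_IF_ERROR(EagerExecute(op, retvals, &num_retvals));
//   // On success num_retvals holds the actual output count; the caller owns
//   // one reference to each returned handle and must Unref them when done.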
942 
943 // TODO(gjn): Consider moving into ExecuteNode class
944 Status EagerKernelExecute(
945     EagerContext* ctx, const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
946     const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
947     const core::RefCountPtr<KernelAndDevice>& kernel,
948     GraphCollector* graph_collector, CancellationManager* cancellation_manager,
949     absl::Span<TensorHandle*> retvals) {
950   profiler::TraceMe activity("EagerKernelExecute",
951                              profiler::TraceMeLevel::kInfo);
952   std::vector<Tensor> outputs(1);
953 
954   ExecuteNodeArgs inputs(op_inputs.size());
955   TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs));
956   // TODO(apassos) figure out how to record stats for ops which are a part of
957   // functions.
958   // TODO(b/111859745): When we support recovering from kernel/device errors, we
959   // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA
960   // device. We don't call it now because it is an unneeded overhead (it
961   // acquires a lock) and we can't recover from errors anyway.
962   ScopedStepContainer* container = ctx->StepContainer();
963   if (container == nullptr) {
964     TF_RETURN_IF_ERROR(kernel->Run(inputs, &outputs, cancellation_manager,
965                                    remote_func_params));
966   } else {
967     TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
968                                    cancellation_manager, remote_func_params));
969   }
970   if (graph_collector != nullptr) {
971     mutex_lock ml(*ctx->MetadataMu());
972     {
973       GraphCollector* collector = ctx->GetGraphCollector();
974       mutex_lock mll(collector->mu);
975 
976       // Adding to partition graphs for backward compatibility.
977       for (const auto& graph : collector->partitioned_graphs) {
978         *ctx->RunMetadataProto()->add_partition_graphs() = graph;
979       }
980 
981       if (collector->dirty) {
982         auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
983         *function_graphs->mutable_post_optimization_graph() =
984             collector->optimized_graph;
985         *function_graphs->mutable_pre_optimization_graph() =
986             collector->raw_graph;
987         for (const auto& graph : collector->partitioned_graphs) {
988           *function_graphs->add_partition_graphs() = graph;
989         }
990       }
991 
992       collector->ClearGraphs();
993     }
994   }
995   DCHECK_EQ(retvals.size(), outputs.size());
996   for (int i = 0; i < retvals.size(); ++i) {
997     DCHECK_EQ(kernel->device(), retvals[i]->op_device());
998     DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
999               retvals[i]->device());
1000 
1001     TF_RETURN_IF_ERROR(retvals[i]->SetTensor(std::move(outputs[i])));
1002   }
1003   return Status::OK();
1004 }
1005 
1006 namespace {
1007 
1008 Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1009                               EagerExecutor* executor, Device* dstd,
1010                               TensorHandle** result) {
1011   TF_RETURN_IF_ERROR(executor->status());
1012   TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
1013       true, ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype,
1014       ctx, result));
1015 
1016   // Note that `h` may not be currently ready. However execution order will
1017   // make sure that `h` is ready before the copy is actually done.
1018   std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, *ctx));
1019   Status s = executor->AddOrExecute(std::move(node));
1020   // If the operation failed, we need to Unref any outputs that were
1021   // allocated.
1022   if (!s.ok()) {
1023     (*result)->Unref();
1024   }
1025 
1026   return s;
1027 }
1028 
1029 }  // namespace
1030 
1031 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1032                          EagerExecutor* executor, Device* device, bool mirror,
1033                          TensorHandle** result) {
1034   Device* send_device = h->DeviceOrHostCPU(*ctx);
1035 
1036   bool sender_is_local = send_device->IsLocal();
1037 
1038   bool recver_is_local = device->IsLocal();
1039 
1040   if (!executor->Async()) {
1041     // In sync mode, always clear error to maintain the same behavior as before.
1042     // TODO(b/141004939): Remove this.
1043     executor->ClearError();
1044   }
1045 
1046   if (sender_is_local && recver_is_local) {
1047     return LocalEagerCopyToDevice(h, ctx, executor, device, result);
1048   } else {
1049 #if defined(IS_MOBILE_PLATFORM)
1050     return errors::Unimplemented(
1051         "Eager's remote execution is not available on mobile devices.");
1052 #else   // !IS_MOBILE_PLATFORM
1053     if (mirror) {
1054       if (h->HasRemoteMirror(device)) {
1055         h->Ref();
1056         *result = h;
1057         return Status::OK();
1058       }
1059     }
1060     uint64 recv_op_id = 0;
1061     if (recver_is_local) {
1062       TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
1063           true, /* d= */ device, /* op_device= */ device,
1064           /*resource_device=*/nullptr, h->dtype, ctx, result));
1065     } else {
1066       uint64 context_id = ctx->GetContextId();
1067       string remote_task;
1068       if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
1069         return errors::InvalidArgument(
1070             "Unable to find remote task corresponding to device ",
1071             device->name());
1072       }
1073       recv_op_id = ctx->RemoteMgr()->NextOpId();
1074       auto tensor_handle_data =
1075           absl::make_unique<UnshapedRemoteTensorHandleData>(
1076               recv_op_id, 0, remote_task, context_id, ctx);
1077       if (mirror) {
1078         TF_RETURN_IF_ERROR(
1079             h->AddUnshapedRemoteMirror(std::move(tensor_handle_data), device));
1080         h->Ref();
1081         *result = h;
1082       } else {
1083         TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
1084             std::move(tensor_handle_data), h->dtype, device, ctx, result));
1085       }
1086     }
1087     auto node = absl::make_unique<eager::RemoteCopyNode>(
1088         ctx, executor, h, result[0], device, recv_op_id);
1089     Status s = executor->AddOrExecute(std::move(node));
1090     if (!s.ok()) {
1091       result[0]->Unref();
1092     }
1093     return s;
1094 #endif  // !IS_MOBILE_PLATFORM
1095   }
1096 }
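// Note (illustrative): for copies to a remote device, mirror=true records the
// copy as a mirror on the original handle and returns that same handle with an
// extra reference, so repeated copies of one tensor to the same device are
// deduplicated; mirror=false always creates a fresh TensorHandle.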
1097 }  // namespace tensorflow
1098