/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/execute.h"

#include <cstddef>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/colocation_graph.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/device_name_utils.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"

namespace tensorflow {

namespace {

const string& DeviceNameOrUnspecified(Device* device) {
  static string* unspecified_string = new string("<unspecified>");
  return (device == nullptr) ? *unspecified_string : device->name();
}

// This function expects *handle to point to an existing tensor handle that is
// currently on "handle_device", but where the operation expects that input to
// reside on "expected_input_device". The function will arrange for this
// transfer to happen, returning OK on success and storing a new handle to the
// equivalent tensor on the correct device in "*result". If an error is
// encountered, it returns a non-OK status and sets "*result" to nullptr.
//
// `op_device` is passed in explicitly because `op->device()` might be
// unset and we might have selected some specific device to run this op on.
Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op,
                                 Device* op_device,
                                 TensorHandle* handle,  // op->Inputs()[i]
                                 int i, Device* handle_device,
                                 Device* expected_input_device,
                                 TensorHandle** result) {
  // Should only be called when these don't match
  DCHECK(expected_input_device != handle_device);
  *result = nullptr;
  const string& op_device_name = DeviceNameOrUnspecified(op_device);

  switch (ctx->GetDevicePlacementPolicy()) {
    case DEVICE_PLACEMENT_SILENT_FOR_INT32:
      // TODO(xpan): See if we could bubble python related error up
      // to python level.
      if (handle->dtype == DT_INT32) {
        // Note: enabling silent copies of int32 tensors to match behavior
        // of graph mode.
        break;
      }
      TF_FALLTHROUGH_INTENDED;
    case DEVICE_PLACEMENT_EXPLICIT:
      return errors::InvalidArgument(
          "Tensors on conflicting devices:"
          " cannot compute ",
          op->Name(), " as input #", i, " was expected to be on ",
          expected_input_device->name(), " but is actually on ",
          handle_device->name(), " (operation running on ", op_device_name, ")",
          " Tensors can be copied explicitly using:"
          " `with tf.device(device_name): x = tf.identity(x)`"
          " or transparently copied by using"
          " tf.config.experimental.set_device_policy('silent')."
          " Copying tensors between devices may slow down your model");
    case DEVICE_PLACEMENT_WARN:
      LOG(WARNING) << "before computing " << op->Name() << " input #" << i
                   << " was expected to be on " << expected_input_device->name()
                   << " but is actually on " << handle_device->name()
                   << " (operation running on " << op_device_name
                   << "). This triggers a copy which can be a performance "
                      "bottleneck.";
      break;
    case DEVICE_PLACEMENT_SILENT:  // Do nothing.
      break;
  }
  // We are only here if the policy is warn or silent copies, so we should
  // trigger a copy.
  TensorHandle* result_handle = nullptr;
  profiler::TraceMe activity(
      [&] {
        return absl::StrCat("_Send input ", i, " from ", handle_device->name(),
                            " to ", expected_input_device->name());
      },
      profiler::TraceMeLevel::kInfo);
  Status status =
      EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device,
                        ctx->MirrorTensors(), &result_handle);
  activity.Stop();
  if (!status.ok()) {
    return errors::Internal("Failed copying input tensor from ",
                            handle_device->name(), " to ",
                            expected_input_device->name(), " in order to run ",
                            op->Name(), ": ", status.error_message());
  }

  *result = result_handle;

  return Status::OK();
}
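
// For reference, the placement policies handled above correspond to a
// Python-level setting. A hedged sketch (argument spellings as referenced in
// the error message above; the exact API surface may differ between TF
// versions):
//
//   tf.config.experimental.set_device_policy('explicit')  # fail on mismatch
//   tf.config.experimental.set_device_policy('warn')      # copy, but log
//   tf.config.experimental.set_device_policy('silent')    # copy transparently
//
//   # Or move the tensor explicitly before calling the op:
//   with tf.device('/device:GPU:0'):
//     x = tf.identity(x)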

// `op_device_name` is the name of the device on which the op will run, if
// any. For functions run using the function library runtime, the device can
// be unspecified.
Status ValidateInputTypeAndPlacement(
    EagerContext* ctx, EagerOperation* op,
    const core::RefCountPtr<KernelAndDevice>& kernel) {
  profiler::TraceMe activity("ValidateInputTypeAndPlacement",
                             profiler::TraceMeLevel::kInfo);
  const int n_inputs = op->Inputs().size();
  if (kernel->num_inputs() != n_inputs) {
    return errors::InvalidArgument("expected ", kernel->num_inputs(),
                                   " inputs, got ", n_inputs);
  }
  const bool skip_remote_copy =
      ctx->LazyCopyFunctionRemoteInputs() && kernel->IsFunction();
  for (int i = 0; i < n_inputs; ++i) {
    TensorHandle* handle = op->Inputs()[i];
    Device* expected_device = kernel->InputDevice(i);
    Device* handle_device = handle->DeviceOrHostCPU(*ctx);
    const bool maybe_copy = !skip_remote_copy || !handle->IsRemote();
    // If the input is already on the right device, then nothing to do.
    if (expected_device != handle_device && maybe_copy) {
      TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
                                                   handle, i, handle_device,
                                                   expected_device, &handle));
      op->UpdateInput(i, handle);
      // Unref handle since it has a ref as an input now
      handle->Unref();
    }
    if (handle->dtype != kernel->input_type(i)) {
      return errors::InvalidArgument(
          "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
          " was expected to be a ", DataTypeString(kernel->input_type(i)),
          " tensor but is a ", DataTypeString(handle->dtype), " tensor");
    }
  }
  return Status::OK();
}

Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
  const auto& node_def = op->MutableAttrs()->BuildNodeDef();
  const OpDef* op_def = nullptr;

  const FunctionDef* function_def =
      op->EagerContext().FuncLibDef()->Find(op->Name());
  if (function_def != nullptr) {
    op_def = &(function_def->signature());
  } else {
    TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
  }

  TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));

  return Status::OK();
}

inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
  return {tensorflow::FingerprintCat64(a.low64, b.low64),
          tensorflow::FingerprintCat64(a.high64, b.high64)};
}

inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const int64 b) {
  auto x = tensorflow::FingerprintCat64(a.low64, b);
  return {x, tensorflow::FingerprintCat64(a.high64, x)};
}
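
// The FingerprintCat128 overloads above are used by EagerLocalExecute below to
// extend an op's attribute-based kernel cache key with per-input device names
// and, for DT_RESOURCE inputs, the variable's dtype and shape, so that a
// cached kernel is never reused across incompatible placements.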

Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
                         Device** result) {
  Device* cpu_device = ctx.HostCPU();
  string device_name;
  if (tensor_handle->IsRemote()) {
    Device* device = tensor_handle->device();
    device_name = device != nullptr ? device->name() : cpu_device->name();
    *result = (device == nullptr ? cpu_device : device);
  } else if (tensor_handle->dtype == DT_RESOURCE) {
    // Use the resource's actual device because it is the device that will
    // influence partitioning the multi-device function.
    const Tensor* tensor;
    // TODO(fishx): Avoid blocking here.
    TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
    const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
    device_name = handle.device();

    Device* input_device;
    TF_RETURN_IF_ERROR(
        ctx.FindDeviceFromName(device_name.c_str(), &input_device));
    *result = input_device;
  } else if (MTypeFromDType(tensor_handle->dtype) == HOST_MEMORY) {
    *result = cpu_device;
  } else {
    Device* device = tensor_handle->device();
    device_name = device != nullptr ? device->name() : cpu_device->name();
    *result = (device == nullptr ? cpu_device : device);
  }
  return Status::OK();
}
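
// Note on the special cases above: resource inputs are attributed to the
// device that actually owns the resource (which matters when partitioning a
// multi-device function), and dtypes whose memory type is HOST_MEMORY are
// treated as living on the host CPU regardless of the handle's device.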

// Appends a TensorShape object to Fprint128 hash.
// For best performance, we would like to avoid dynamic memory allocation in
// this function.
// If "shape" has unknown rank, we attach "?" to hashed content; otherwise we
// attach every dim size to hashed content.
void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
                                    Fprint128* fingerprint) {
  if (shape.unknown_rank()) {
    char c = '?';
    *fingerprint = FingerprintCat128(*fingerprint, c);
  } else {
    for (int i = 0; i < shape.dims(); i++) {
      int64 dim = shape.dim_size(i);
      *fingerprint = FingerprintCat128(*fingerprint, dim);
    }
  }
}

Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
                          bool* compile_with_xla) {
  if (!op->is_function()) {
    *compile_with_xla = false;
    return Status::OK();
  }

  if (op->remote_func_params().has_value() &&
      op->remote_func_params().value().step_id.has_value()) {
    // If the op is a component of a multi-device function, don't compile it
    // with XLA.
    *compile_with_xla = false;
    return Status::OK();
  }

  // Does node have an explicit request to compile or not?
  Status status = op->Attrs().Get(kXlaMustCompileAttr, compile_with_xla);
  if (status.ok()) {
    DVLOG(2) << "Caller explicitly requested "
             << (*compile_with_xla ? "" : "not ")
             << "to compile with XLA: " << op->DebugString();
    return Status::OK();
  }

  // Does FunctionDef have an explicit request to compile or not?
  const FunctionDef* function_def =
      ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name());
  if (function_def == nullptr) {
    return errors::NotFound("Failed to find function '", op->Name(), "'");
  }

  status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaMustCompileAttr,
                       compile_with_xla);
  if (status.ok()) {
    DVLOG(2) << "Function definition explicitly specifies "
             << (*compile_with_xla ? "" : "not ") << "to compile with XLA";
    return Status::OK();
  }

  // No explicit requests. Compile for XLA devices by default.
  if (op->GetDeviceParsedName().type == "TPU" ||
      op->GetDeviceParsedName().type == "XLA_GPU" ||
      op->GetDeviceParsedName().type == "XLA_CPU") {
    DVLOG(2) << "Compiling " << op->Name()
             << " with XLA because it is running on an XLA device "
             << op->GetDeviceParsedName().type;
    *compile_with_xla = true;
  } else {
    *compile_with_xla = false;
  }

  return Status::OK();
}
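
// A hedged illustration of how the explicit request checked above is usually
// set from Python. The keyword argument name has changed across TF releases
// (e.g. `experimental_compile`, later `jit_compile`), so treat this as a
// sketch rather than a definitive API reference:
//
//   @tf.function(experimental_compile=True)  # marks the function for XLA
//   def matmul(a, b):
//     return tf.matmul(a, b)
//
// Functions with no such marker fall through to the device-type default
// below: TPU, XLA_GPU and XLA_CPU devices compile by default.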

// There are a lot of references to devices in this function and around.
// Here is what they mean:
// EagerOperation::Device(): The device on which the user requested the op
//   be executed, except if we had to change the device due to resource inputs
//   or CPU pinning. If the user did not request a device, the op does not
//   take resources, and we did not pin it to CPU, the device can be nullptr.
// KernelAndDevice::Device(): The first time we see an op (combined with
//   its attributes), we need to create a KernelAndDevice object for it.
//   If op->Device() is a nullptr, we select a device for the op when
//   creating the KernelAndDevice. A concrete device will always be selected
//   here except when `op` is a function to be executed using function library
//   runtime. In this case, we don't select a device because running
//   a function with explicitly requested device has different behavior than
//   running without an explicitly requested device.
Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
                         int* num_retvals) {
  MEMDEBUG_CACHE_OP(op->op_name());
  profiler::TraceMe activity(
      [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
      profiler::TraceMeLevel::kInfo);
  EagerContext& ctx = op->EagerContext();
  auto& executor = op->Executor();
  TF_RETURN_IF_ERROR(executor.status());
  Device* device = op->Device();

  Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->GetDeviceName());

  std::vector<Device*> input_dev_ptrs;
  std::unordered_map<int, DtypeAndPartialTensorShape>
      input_resource_variable_dtypes_and_shapes;
  // We can eliminate some overhead by running simple functions using regular
  // CallOp kernel. However, it is tricky to figure out which functions should
  // be run using CallOp. Also, currently CallOp runs neither optimization
  // passes (needed for TPU/XLA) nor grappler.
  // Here are some cases where a function should be run in multi-device mode:
  //  - Function takes at least two resources on different devices.
  //  - Function takes a resource on deviceA and a body op explicitly placed
  //    on deviceB.
  //  - Function has a colocation constraint.
  //  - Function has an explicit device annotation (which might not be using
  //    full canonical device name) different from op_device. Note that false
  //    positives are ok.
  //  - Function has a node or a (node) attribute that can potentially make
  //    the function multi-device after a rewrite pass (e.g. various XLA/TPU
  //    special nodes and attributes)
  if (op->is_function()) {
    profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
                               profiler::TraceMeLevel::kInfo);
    input_dev_ptrs.reserve(op->Inputs().size());
    // When LazyCopyFunctionRemoteInputs is disabled, all inputs need to be on
    // local devices, since we execute a remote function through worker
    // service, which doesn't accept remote inputs.
    for (int i = 0; i < op->Inputs().size(); i++) {
      TensorHandle* input = op->Inputs()[i];
      if (!ctx.LazyCopyFunctionRemoteInputs() && input->IsRemote()) {
        TensorHandle* handle = nullptr;
        TF_RETURN_IF_ERROR(EagerCopyToDevice(
            input, &ctx, &executor, device == nullptr ? ctx.HostCPU() : device,
            ctx.MirrorTensors(), &handle));
        op->UpdateInput(i, handle);
        // Unref handle since it has a ref as an input now
        handle->Unref();
        input = handle;
      }

      // Get device for this input, and add it to 'cache_key'.
      Device* input_device;
      TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
      input_dev_ptrs.push_back(input_device);
      cache_key =
          FingerprintCat128(cache_key, Fingerprint128(input_device->name()));

      // If input is a ResourceHandle, get its resource handle dtypes and
      // shapes and add them to 'cache_key'.
      if (input->dtype == DT_RESOURCE) {
        // We only care about data type and shape for resource variable inputs.
        // But we have no way to tell if input is resource variable (other than
        // looking it up in ResourceMgr, which is slow). So we just get
        // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
        // resource_dtypes_and_shapes is not empty, take the first element.
        std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
        TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
            &resource_dtypes_and_shapes));
        if (!resource_dtypes_and_shapes.empty()) {
          const DtypeAndPartialTensorShape& dtype_and_shape =
              resource_dtypes_and_shapes.at(0);
          input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;

          // Add _Arg index, dtype and shape to "cache_key".
          cache_key = FingerprintCat128(cache_key, i);
          DataType dtype = dtype_and_shape.dtype;
          cache_key = FingerprintCat128(cache_key, dtype);
          AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key);
        }
      }
    }
  }

  core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
  if (kernel == nullptr) {
    DVLOG(2) << "Creating new kernel for " << op->Name() << " on device "
             << DeviceNameOrUnspecified(op->Device());
    bool run_function_with_flr = false;
    if (op->is_function()) {
      bool compile_with_xla;
      TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
      if (compile_with_xla) {
        // Note that it is not ideal, but currently correct, to set this
        // attribute after computing the kernel cache key above.
        // Note: If the attribute is already set to true, this is a noop.
        op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
      } else {
        run_function_with_flr = true;
      }
    }

    const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
    if (device == nullptr) {
      PrioritizedDeviceTypeVector supported_devs;
      TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
          ctx.prioritized_device_type_list(), ndef, &supported_devs,
          &ctx.HostCPU()->parsed_name()));
      if (supported_devs.empty()) {
        return errors::NotFound("Could not find valid device for node.\nNode:",
                                FormatNodeDefForError(ndef),
                                "\nAll kernels registered for op ", ndef.op(),
                                " :\n", KernelsRegisteredForOp(ndef.op()));
      }
      TF_RETURN_IF_ERROR(ctx.SelectDevice(op->GetDeviceParsedName(),
                                          supported_devs, DT_INVALID, &device));

      DVLOG(1) << "Placer place op [" << op->Name()
               << "] on device: " << device->name();
      DVLOG(4) << "Available kernels for " << op->Name() << " are "
               << KernelsRegisteredForOp(op->Name());
      op->SetDevice(device);
    }
    if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
      string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
                                   DeviceNameOrUnspecified(device));
      if (!logging::LogToListeners(msg)) {
        LOG(INFO) << msg;
      }
    }

    FunctionLibraryRuntime* flr =
        device == nullptr ? nullptr : ctx.func_lib(device);
    if (device != nullptr && flr == nullptr) {
      return errors::Unavailable(
          "Unable to find a FunctionLibraryRuntime corresponding to device ",
          device->name());
    }
    auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
                                                               : ctx.runner();
    GraphCollector* graph_collector = nullptr;
    if (ctx.ShouldStoreGraphs()) {
      graph_collector = ctx.GetGraphCollector();
    }
    // Treat the function as multi_device only when we are not compiling
    // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
    // will create an XlaLaunchOp kernel to compile and run the function.
    if (run_function_with_flr) {
      // Multi-device functions don't use the rendezvous from eager context.
      // If we use that rendezvous, multiple concurrent calls to the same
      // function will likely result in collisions. However, this also means
      // that we don't support legitimate sending/receiving across function
      // boundary.
      DVLOG(2) << "Running " << ndef.op() << " using multi-device function. "
               << "Full node_def=" << ndef.DebugString();
      std::function<int64()> get_op_id = nullptr;
#if !defined(IS_MOBILE_PLATFORM)
      if (ctx.LazyCopyFunctionRemoteInputs()) {
        get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
      }
#endif  // IS_MOBILE_PLATFORM
      kernel.reset(new KernelAndDeviceFunc(
          flr, ctx.pflr(), std::move(input_dev_ptrs),
          std::move(input_resource_variable_dtypes_and_shapes), runner,
          ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
          [&ctx](const int64 step_id) { return ctx.CreateRendezvous(step_id); },
          get_op_id));
    } else {
      DVLOG(2) << "Running " << ndef.op() << " using op kernel. "
               << "Full node_def=" << ndef.DebugString();
      kernel.reset(new KernelAndDeviceOp(
          ctx.GetRendezvous(), ctx.LogMemory(), flr, runner,
          ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
    }

    TF_RETURN_IF_ERROR(kernel->Init(ndef, graph_collector));

    if (op->is_function()) {
      ctx.AddKernelToCache(cache_key, kernel.get());
    } else {
      // Exclude tf.data op kernels from being cached. The reason for this is
      // that tf.data op kernels that accept a user-defined function will have
      // a unique cache key every time they are executed (because the
      // user-defined function is traced every time). Caching such kernels
      // provides no
      // benefit and in some cases results in linear memory growth of user
      // programs that build input pipeline graphs in a loop.
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def));
      if (!data::DatasetOpKernel::IsDatasetOp(op_def)) {
        ctx.AddKernelToCache(cache_key, kernel.get());
      }
    }
  }
  const DataTypeVector& output_dtypes = kernel->output_dtypes();
  const size_t num_outputs = static_cast<int>(output_dtypes.size());
  if (num_outputs > *num_retvals) {
    return errors::InvalidArgument("Expecting ", num_outputs,
                                   " outputs, but *num_retvals is ",
                                   *num_retvals);
  }
  *num_retvals = num_outputs;
  TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));

  GraphCollector* graph_collector = nullptr;
  if (ctx.ShouldStoreGraphs()) {
    graph_collector = ctx.GetGraphCollector();
  }

  const bool async = executor.Async();
  for (int i = 0; i < num_outputs; ++i) {
    TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
        async,
        /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
        /* op_device= */ kernel->device(),
        /* resource_device= */ kernel->OutputResourceDevice(i),
        output_dtypes[i], &ctx, &retvals[i]));
  }

  Status s;
  if (async) {
    auto node = absl::make_unique<ExecuteNode>(
        &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
        graph_collector, output_dtypes, op->GetCancellationManager(),
        executor.Async(), absl::Span<TensorHandle*>(retvals, num_outputs));
    // For async mode, execution order will make sure that all
    // input handles are ready before executing them.
    // TODO(b/137118203): Consider executing "cheap" kernels inline for
    // performance.
    s = executor.AddOrExecute(std::move(node));
  } else {
    ExecuteNode node(&ctx, op->Inputs(), op->remote_func_params(),
                     std::move(kernel), graph_collector, output_dtypes,
                     op->GetCancellationManager(), executor.Async(),
                     {retvals, num_outputs});
    s = executor.SyncExecute(&node);
  }
  // Since the operation failed, we need to Unref any outputs that were
  // allocated.
  if (!s.ok()) {
    for (int i = 0; i < num_outputs; ++i) {
      retvals[i]->Unref();
    }
  }

  return s;
}

#if !defined(IS_MOBILE_PLATFORM)
void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
  EagerContext& ctx = op->EagerContext();

  remote_op->set_id(ctx.RemoteMgr()->NextOpId());
  remote_op->set_name(op->Name());

  op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
  remote_op->set_device(op->Device()->name());
  remote_op->set_is_function(op->is_function());
}

Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
                                    const DataTypeVector& output_dtypes,
                                    TensorHandle** retvals) {
  if (remote_op.name() == "VarHandleOp") {
    if (output_dtypes.size() != 1) {
      return errors::Internal("VarHandleOp should only have one output.");
    }
    if (output_dtypes[0] != DT_RESOURCE) {
      return errors::Internal(
          "The output of VarHandleOp should be a DT_RESOURCE.");
    }
    AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
    const AttrValue* dtype;
    TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
    const AttrValue* shape;
    TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
    retvals[0]->SetResourceHandleDtypeAndShape(
        {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
  }
  return Status::OK();
}

Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
                          int* num_retvals) {
  EagerContext& ctx = op->EagerContext();

  // TODO(fishx): Remove following code when lazy tensor copy is ready.
  if (op->Device() == nullptr) {
    tensorflow::Device* device = nullptr;
    string device_name = op->GetDeviceName();
    TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
    op->SetDevice(device);
  }

  core::RefCountPtr<eager::EagerClient> eager_client;
  uint64 context_id = ctx.GetContextId();
  TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client));
  string remote_task;
  if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
    return errors::InvalidArgument(
        "Unable to find remote task corresponding to device ",
        op->Device()->name());
  }

  std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
  request->set_context_id(context_id);

  eager::Operation* remote_op = request->add_queue()->mutable_operation();

  {
    profiler::TraceMe activity("CopyInputToExpectedDevice",
                               profiler::TraceMeLevel::kInfo);
    const bool eagerly_copy_function_remote_inputs =
        !ctx.LazyCopyFunctionRemoteInputs() || !op->is_function();
    for (int i = 0; i < op->Inputs().size(); i++) {
      tensorflow::TensorHandle* input = op->Inputs()[i];
      tensorflow::Device* input_device = input->device();
      const string* input_device_name = &input->DeviceOrHostCPU(ctx)->name();
      bool serialize_resource_dtype_and_shape = false;
      if (op->Device() != input_device &&
          // If the expected and actual devices are on the same task, don't
          // explicitly copy, and instead depend on the copy to happen locally
          // when the op is executed on the device.
          !ctx.OnSameTask(op->Device(), input_device)) {
        if (eagerly_copy_function_remote_inputs ||
            input->DeviceOrHostCPU(ctx)->IsLocal()) {
          tensorflow::Device* remote_cpu_device;
          TF_RETURN_IF_ERROR(
              ctx.CPUDeviceOnTask(op->Device(), &remote_cpu_device));
          // TODO(b/110044833): It's possible the same tensor gets copied to
          // the remote device repeatedly.
          // Always copy to the remote CPU so that the actual device can be
          // correctly determined after the kernel is selected/instantiated,
          // since the op might have its inputs on host memory.
          TensorHandle* handle = op->Inputs()[i];
          Device* handle_device = handle->DeviceOrHostCPU(ctx);
          // If the input is already on the right device, then nothing to do.
          if (remote_cpu_device != handle_device) {
            TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
                &ctx, op, op->Device(), handle, i, handle_device,
                remote_cpu_device, &handle));
            op->UpdateInput(i, handle);
            input = handle;
            input_device = remote_cpu_device;
            input_device_name = &remote_cpu_device->name();
            // Unref handle since it has a ref as an input now
            handle->Unref();
          }
        } else {
          serialize_resource_dtype_and_shape =
              (input->dtype == DT_RESOURCE) &&
              (!input->HasResourceShapeMirror(op->Device()));
        }
      }
      auto* input_handle = remote_op->add_inputs();
      TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
          input, input_handle, input_device, *input_device_name,
          serialize_resource_dtype_and_shape));
      if (!input_handle->resource_dtypes_and_shapes().empty()) {
        auto tensor_handle_data =
            absl::make_unique<UnshapedRemoteTensorHandleData>(
                input_handle->op_id(), input_handle->output_num(), remote_task,
                context_id, &ctx);
        TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
            std::move(tensor_handle_data), op->Device()));
      }
    }
  }

  PrepareRemoteOp(remote_op, op);

  DataTypeVector output_dtypes;
  TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));

  const size_t num_outputs = static_cast<int>(output_dtypes.size());
  if (num_outputs != *num_retvals) {
    return errors::InvalidArgument(
        "num_retvals does not match expected output dtypes");
  }
  *num_retvals = num_outputs;

  tensorflow::Device* op_device = op->Device();
  const tensorflow::uint64 id = remote_op->id();
  for (int i = 0; i < num_outputs; ++i) {
    // TODO(nareshmodi): Change the callback to instead add the decref to a
    // list of pending decrefs that we can send as a batch with the next
    // execute.

    // The device_ and resource_device_ of this TensorHandle might be
    // incorrect. It is pretty hard to make it correct because for
    // multi-device functions, we don't know the output device until the
    // function is instantiated. Luckily, we don't need to know the correct
    // remote device here. We just need to know that it is remote. If we need
    // to copy this tensor to this process, the remote end will know the
    // correct device of this handle.
    Status status = TensorHandle::CreateUnshapedRemoteHandle(
        id, i, remote_task, context_id, output_dtypes[i], op_device, &ctx,
        &retvals[i]);
    if (!status.ok()) {
      for (int j = 0; j < i; ++j) {
        retvals[j]->Poison(errors::Internal(
            "Failed to construct unshaped remote tensor handle at index ", i,
            " for op ", op->Name()));
      }
      return status;
    }
  }

  if (ctx.LazyCopyFunctionRemoteInputs()) {
    // Store the data type and shape of a remote resource variable on the
    // corresponding remote TensorHandle (output of 'VarHandleOp').
    // If the variable is an input of a remote function, the function may need
    // the type and shape during function instantiation. When
    // LazyCopyFunctionRemoteInputs is enabled, we no longer copy the resource
    // handle (contains the type and shape) of the variable to the default
    // function device. Instead, we store the type and shape on eager master
    // and send them to the default function device along with the
    // EnqueueRequest.
    TF_RETURN_IF_ERROR(
        StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
  }

  auto& executor = op->Executor();
  DVLOG(4) << "Execute remote eager op: " << op->Name()
           << " (is async?: " << executor.Async() << ").";

  std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
      std::move(request), op_device, eager_client.get(),
      op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
      op->Inputs(), {retvals, num_outputs}));
  Status s = executor.AddOrExecute(std::move(node));
  // Since the operation failed, we need to Unref any outputs that were
  // allocated.
  if (!s.ok()) {
    for (int i = 0; i < num_outputs; ++i) {
      retvals[i]->Unref();
    }
  }

  return s;
}
#endif  // IS_MOBILE_PLATFORM

// These ops are not pinnable since they generate data. It can be slower to
// generate and then copy the data instead of just generating the data on the
// device directly.
bool IsPinnableOp(const string& op_type) {
  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
      "RandomUniform",
      "RandomUniformInt",
      "RandomStandardNormal",
      "StatelessRandomUniform",
      "StatelessRandomUniformInt",
      "StatelessRandomNormal",
  });

  // XRT ops refer to per-device handles that are not safe to move between
  // devices.
  return unpinnable_ops->find(op_type) == unpinnable_ops->end() &&
         !absl::StartsWith(op_type, "XRT");
}

// The Op device may be updated if:
// - A resource-touching input is specified: all resource-touching ops run on
//   the device the resource lives on, regardless of anything else that has
//   been specified. This is identical to the graph mode behavior.
//
// - All op inputs are on the CPU, small (<64 elements) and integers
//   (int32/int64). This can be disabled by setting the environment variable
//   "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
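//
// A hedged eager-mode illustration of these rules (assumes a GPU named
// "GPU:0" is available; treat this as a sketch, not a guarantee):
//
//   v = tf.Variable(1.0)   # resource placed on GPU:0
//   v.assign_add(1.0)      # pinned to GPU:0 because its input is a resource
//   tf.add(1, 2)           # tiny int32 inputs on the CPU -> pinned to CPU,
//                          # unless TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING
//                          # is set to "0" or "false"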
Status MaybeUpdateOpDevice(EagerOperation* op) {
  const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
  if (op->is_function() || exempt_ops.find(op->Name()) != exempt_ops.end()) {
    // Don't update the device of direct function calls.
    // Particularly, if the user did not explicitly request any device for this
    // function, picking a device would result in this device being the default
    // for nodes inside the function. This is undesirable for multi-device
    // functions since the not-explicitly-placed nodes inside the body will all
    // end up on this default device.
    return Status::OK();
  }
  EagerContext& ctx = op->EagerContext();
  bool all_inputs_eligible_for_cpu_pinning =
      ctx.PinSmallOpsToCPU() && !op->is_function() && IsPinnableOp(op->Name());
  Device* op_device = op->Device() == nullptr ? ctx.HostCPU() : op->Device();
  for (int i = 0; i < op->Inputs().size(); ++i) {
    TensorHandle* tensor_handle = op->Inputs()[i];
    if (tensor_handle->dtype == DT_RESOURCE) {
      Device* resource_device = tensor_handle->resource_device();
      DVLOG(2) << "for op " << op->Name() << " input " << i << " "
               << DataTypeString(tensor_handle->dtype)
               << " input device = " << resource_device->name()
               << ", op device = " << op_device->name();
      // We check for `op->Device() == nullptr` because it can be later
      // interpreted as unspecified device and a different device can
      // be selected based on device priority. If any input to an op
      // is a resource we must pin it to prevent different device selection.
      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
      if (resource_device != op_device || op->Device() == nullptr) {
        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
                 << "device of operation " << op->Name() << " to "
                 << resource_device->name() << " because input #" << i
                 << " is a resource on this device.";
        op->SetDevice(resource_device);
      }
      all_inputs_eligible_for_cpu_pinning = false;
      // No point in looking at other inputs. If there are other resources,
      // they must have the same device and we already declared the op to be
      // ineligible for CPU pinning.
      break;
    } else if (all_inputs_eligible_for_cpu_pinning) {
      Device* input_device = tensor_handle->DeviceOrHostCPU(ctx);
      DVLOG(2) << "for op " << op->Name() << " input " << i << " "
               << DataTypeString(tensor_handle->dtype)
               << " input device = " << input_device->name()
               << ", op device = " << op_device->name();

      // The input must be on the CPU for the op to remain eligible.
      if (input_device != ctx.HostCPU()) {
        all_inputs_eligible_for_cpu_pinning = false;
        continue;
      }

      if (tensor_handle->dtype != DataType::DT_INT32 &&
          tensor_handle->dtype != DataType::DT_INT64) {
        all_inputs_eligible_for_cpu_pinning = false;
        continue;
      }

      int64 num_elements;
      TF_RETURN_IF_ERROR(tensor_handle->NumElements(&num_elements));
      if (num_elements > 64) {
        all_inputs_eligible_for_cpu_pinning = false;
      }
    }
  }

  // Ops without inputs are usually ops that generate a tensor in some way and
  // usually require being present on whatever device they are scheduled on
  // (e.g. VarHandleOp or _Recv).
  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
  // an op, but there is a GPU kernel?
  if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
    DVLOG(1) << "Forcing op " << op->Name()
             << " to be on the CPU since all input tensors have an "
                "int32/int64 dtype, and are small (less than 64 elements).";
    op->SetDevice(ctx.HostCPU());
  }

  return Status::OK();
}
}  // namespace

Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
                    int* num_retvals) {
  profiler::TraceMe activity(
      [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
      profiler::TraceMeLevel::kInfo);
  TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));

  if (!op->Executor().Async()) {
    // In sync mode, always clear error to maintain the same behavior as
    // before. TODO(b/141004939): Remove this.
    op->Executor().ClearError();
  }

  std::unique_ptr<tensorflow::EagerOperation> out_op;
  TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
      EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));

  if (op->IsLocal()) {
    if (out_op) {
      op = out_op.get();
    }
    return EagerLocalExecute(op, retvals, num_retvals);
  }

  if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
    string msg = strings::StrCat(
        "Executing op ", op->Name(), " on task ",
        DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
    if (!logging::LogToListeners(msg)) {
      LOG(INFO) << msg;
    }
  }

#if defined(IS_MOBILE_PLATFORM)
  return errors::Unimplemented(
      "Eager's remote execution is not available on mobile devices.");
#else  // !IS_MOBILE_PLATFORM
  if (out_op) {
    op = out_op.get();
  }
  return EagerRemoteExecute(op, retvals, num_retvals);
#endif  // !IS_MOBILE_PLATFORM
}

// TODO(gjn): Consider moving into ExecuteNode class
Status EagerKernelExecute(
    EagerContext* ctx, const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    const core::RefCountPtr<KernelAndDevice>& kernel,
    GraphCollector* graph_collector, CancellationManager* cancellation_manager,
    absl::Span<TensorHandle*> retvals) {
  profiler::TraceMe activity("EagerKernelExecute",
                             profiler::TraceMeLevel::kInfo);
  std::vector<Tensor> outputs(1);

  ExecuteNodeArgs inputs(op_inputs.size());
  TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs));
  // TODO(apassos) figure out how to record stats for ops which are a part of
  // functions.
  // TODO(b/111859745): When we support recovering from kernel/device errors,
  // we would need to call XlaDevice::EnsureDeviceContextOk() before using an
  // XLA device. We don't call it now because it is an unneeded overhead (it
  // acquires a lock) and we can't recover from errors anyway.
  ScopedStepContainer* container = ctx->StepContainer();
  if (container == nullptr) {
    TF_RETURN_IF_ERROR(kernel->Run(inputs, &outputs, cancellation_manager,
                                   remote_func_params));
  } else {
    TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
                                   cancellation_manager, remote_func_params));
  }
  if (graph_collector != nullptr) {
    mutex_lock ml(*ctx->MetadataMu());
    {
      GraphCollector* collector = ctx->GetGraphCollector();
      mutex_lock mll(collector->mu);

      // Adding to partition graphs for backward compatibility.
      for (const auto& graph : collector->partitioned_graphs) {
        *ctx->RunMetadataProto()->add_partition_graphs() = graph;
      }

      if (collector->dirty) {
        auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
        *function_graphs->mutable_post_optimization_graph() =
            collector->optimized_graph;
        *function_graphs->mutable_pre_optimization_graph() =
            collector->raw_graph;
        for (const auto& graph : collector->partitioned_graphs) {
          *function_graphs->add_partition_graphs() = graph;
        }
      }

      collector->ClearGraphs();
    }
  }
  DCHECK_EQ(retvals.size(), outputs.size());
  for (int i = 0; i < retvals.size(); ++i) {
    DCHECK_EQ(kernel->device(), retvals[i]->op_device());
    DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
              retvals[i]->device());

    TF_RETURN_IF_ERROR(retvals[i]->SetTensor(std::move(outputs[i])));
  }
  return Status::OK();
}

namespace {

Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                              EagerExecutor* executor, Device* dstd,
                              TensorHandle** result) {
  TF_RETURN_IF_ERROR(executor->status());
  TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
      true, ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype,
      ctx, result));

  // Note that `h` may not be currently ready. However execution order will
  // make sure that `h` is ready before the copy is actually done.
  std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, *ctx));
  Status s = executor->AddOrExecute(std::move(node));
  // Since the operation failed, we need to Unref any outputs that were
  // allocated.
  if (!s.ok()) {
    (*result)->Unref();
  }

  return s;
}

}  // namespace

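// Copies `h` to `device`. When `mirror` is set and `device` already holds a
// remote mirror of `h`, `h` itself is returned with an extra reference and no
// copy is scheduled. Local-to-local copies go through LocalEagerCopyToDevice
// above, which enqueues a CopyToDeviceNode on the executor; any copy that
// involves a remote device is enqueued as a RemoteCopyNode.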
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                         EagerExecutor* executor, Device* device, bool mirror,
                         TensorHandle** result) {
  Device* send_device = h->DeviceOrHostCPU(*ctx);

  bool sender_is_local = send_device->IsLocal();

  bool recver_is_local = device->IsLocal();

  if (!executor->Async()) {
    // In sync mode, always clear error to maintain the same behavior as
    // before. TODO(b/141004939): Remove this.
    executor->ClearError();
  }

  if (sender_is_local && recver_is_local) {
    return LocalEagerCopyToDevice(h, ctx, executor, device, result);
  } else {
#if defined(IS_MOBILE_PLATFORM)
    return errors::Unimplemented(
        "Eager's remote execution is not available on mobile devices.");
#else  // !IS_MOBILE_PLATFORM
    if (mirror) {
      if (h->HasRemoteMirror(device)) {
        h->Ref();
        *result = h;
        return Status::OK();
      }
    }
    uint64 recv_op_id = 0;
    if (recver_is_local) {
      TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
          true, /* d= */ device, /* op_device= */ device,
          /*resource_device=*/nullptr, h->dtype, ctx, result));
    } else {
      uint64 context_id = ctx->GetContextId();
      string remote_task;
      if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
        return errors::InvalidArgument(
            "Unable to find remote task corresponding to device ",
            device->name());
      }
      recv_op_id = ctx->RemoteMgr()->NextOpId();
      auto tensor_handle_data =
          absl::make_unique<UnshapedRemoteTensorHandleData>(
              recv_op_id, 0, remote_task, context_id, ctx);
      if (mirror) {
        TF_RETURN_IF_ERROR(
            h->AddUnshapedRemoteMirror(std::move(tensor_handle_data), device));
        h->Ref();
        *result = h;
      } else {
        TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
            std::move(tensor_handle_data), h->dtype, device, ctx, result));
      }
    }
    auto node = absl::make_unique<eager::RemoteCopyNode>(
        ctx, executor, h, result[0], device, recv_op_id);
    Status s = executor->AddOrExecute(std::move(node));
    if (!s.ok()) {
      result[0]->Unref();
    }
    return s;
#endif  // !IS_MOBILE_PLATFORM
  }
}
}  // namespace tensorflow