1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/execute.h"
17 
18 #include <cstddef>
19 #include <vector>
20 
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/lib/core/refcount.h"
29 #include "tensorflow/core/platform/platform.h"
30 // clang-format on
31 
32 #include "absl/container/inlined_vector.h"
33 #include "absl/strings/match.h"
34 #include "absl/strings/str_cat.h"
35 #include "absl/types/optional.h"
36 #include "tensorflow/c/tf_tensor_internal.h"
37 #include "tensorflow/compiler/jit/defs.h"
38 #include "tensorflow/core/common_runtime/device.h"
39 #include "tensorflow/core/common_runtime/device_set.h"
40 #include "tensorflow/core/common_runtime/eager/context.h"
41 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
42 #include "tensorflow/core/common_runtime/eager/execute_node.h"
43 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
44 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
45 #include "tensorflow/core/common_runtime/colocation_graph.h"
46 #include "tensorflow/core/framework/dataset.h"
47 #include "tensorflow/core/framework/function.h"
48 #include "tensorflow/core/framework/logging.h"
49 #include "tensorflow/core/framework/node_def_util.h"
50 #include "tensorflow/core/framework/tensor_reference.h"
51 #include "tensorflow/core/framework/types.pb.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/profiler/lib/traceme.h"
54 #include "tensorflow/core/protobuf/error_codes.pb.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56 #if !defined(IS_MOBILE_PLATFORM)
57 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
58 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
59 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
60 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
61 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
62 #endif  // IS_MOBILE_PLATFORM
63 #include "tensorflow/core/framework/step_stats.pb.h"
64 #include "tensorflow/core/framework/tensor.h"
65 #include "tensorflow/core/framework/types.h"
66 #include "tensorflow/core/lib/core/status.h"
67 #include "tensorflow/core/lib/gtl/cleanup.h"
68 #include "tensorflow/core/lib/gtl/flatset.h"
69 #include "tensorflow/core/lib/random/random.h"
70 #include "tensorflow/core/platform/env.h"
71 #include "tensorflow/core/platform/mutex.h"
72 #include "tensorflow/core/util/ptr_util.h"
73 #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
74 
75 namespace tensorflow {
76 
77 namespace {
78 
79 const string& DeviceNameOrUnspecified(Device* device) {
80   static string* unspecified_string = new string("<unspecified>");
81   return (device == nullptr) ? *unspecified_string : device->name();
82 }
83 
84 // Returns whether a kernel should be cached.
85 bool KernelCacheEnabled(const OpDef& op_def) {
86   if (data::DatasetOpKernel::IsDatasetOp(&op_def)) {
87     return false;
88   }
89   // TODO(b/162540360): Revisit a way to mark kernels as uncachable once we have
90   // 5+ kernels to exclude.
91   return true;
92 }
93 
94 // This function expects *handle to point to an existing tensor handle that is
95 // currently on "handle_device", but where the operation expects that input to
96 // reside on "expected_input_device".  The function will arrange for this
97 // transfer to happen and will return OK on success and will store a new
98 // handle to the equivalent tensor on the correct device in "*result".  Or if an
99 // error is encountered, it will return a non-OK status and set "*result" to
100 // nullptr.
101 //
102 // `op_device` is passed in explicitly because `op->device()` might be
103 // unset and we might have selected some specific device to run this op on.
104 Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op,
105                                  Device* op_device,
106                                  TensorHandle* handle,  // op->Inputs()[i]
107                                  int i, Device* handle_device,
108                                  Device* expected_input_device,
109                                  TensorHandle** result) {
110   // Should only be called when these don't match
111   DCHECK(expected_input_device != handle_device);
112   *result = nullptr;
113   const string& op_device_name = DeviceNameOrUnspecified(op_device);
114 
115   switch (ctx->GetDevicePlacementPolicy()) {
116     case DEVICE_PLACEMENT_SILENT_FOR_INT32:
117       // TODO(xpan): See if we could bubble python related error up
118       // to python level.
119       if (handle->dtype == DT_INT32) {
120         // Note: enabling silent copies of int32 tensors to match behavior
121         // of graph mode.
122         break;
123       }
124       TF_FALLTHROUGH_INTENDED;
125     case DEVICE_PLACEMENT_EXPLICIT:
126       // tf.identity is allowed to copy, as indicated in the error message
127       // below.
128       if (op->Name() == "Identity" || op->Name() == "IdentityN") {
129         break;
130       }
131       return errors::InvalidArgument(
132           "Tensors on conflicting devices:"
133           " cannot compute ",
134           op->Name(), " as input #", i, " was expected to be on ",
135           expected_input_device->name(), " but is actually on ",
136           handle_device->name(), " (operation running on ", op_device_name, ")",
137           " Tensors can be copied explicitly using:"
138           " `with tf.device(device_name): x = tf.identity(x)`"
139           " or transparently copied by using"
140           " tf.config.experimental.set_device_policy('silent')."
141           " Copying tensors between devices may slow down your model");
142     case DEVICE_PLACEMENT_WARN:
143       LOG(WARNING) << "before computing " << op->Name() << " input #" << i
144                    << " was expected to be on " << expected_input_device->name()
145                    << " but is actually on " << handle_device->name()
146                    << " (operation running on " << op_device_name
147                    << "). This triggers a copy which can be a performance "
148                       "bottleneck.";
149       break;
150     case DEVICE_PLACEMENT_SILENT:  // Do nothing.
151       break;
152   }
153   // We are only here if the policy is warn or silent copies, so we should
154   // trigger a copy.
155   TensorHandle* result_handle = nullptr;
156   profiler::TraceMe activity(
157       [&] {
158         return absl::StrCat("_Send input ", i, " from ", handle_device->name(),
159                             " to ", expected_input_device->name());
160       },
161       profiler::TraceMeLevel::kInfo);
162   Status status =
163       EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device,
164                         /* mirror= */ true, &result_handle);
165   activity.Stop();
166   if (!status.ok()) {
167     return Status(
168         status.code(),
169         absl::StrCat("Failed copying input tensor from ", handle_device->name(),
170                      " to ", expected_input_device->name(), " in order to run ",
171                      op->Name(), ": ", status.error_message()));
172   }
173 
174   *result = result_handle;
175 
176   return Status::OK();
177 }
178 
179 // `op_device_name` is the name of the device on which the op will run, if any.
180 // For functions run using the function library runtime, the device can be
181 // unspecified.
182 Status ValidateInputTypeAndPlacement(
183     EagerContext* ctx, EagerOperation* op,
184     const core::RefCountPtr<KernelAndDevice>& kernel) {
185   profiler::TraceMe activity("ValidateInputTypeAndPlacement",
186                              profiler::TraceMeLevel::kInfo);
187   const int n_inputs = op->Inputs().size();
188   if (kernel->num_inputs() != n_inputs) {
189     return errors::InvalidArgument("expected ", kernel->num_inputs(),
190                                    " inputs, got ", n_inputs);
191   }
192   const bool is_function = kernel->IsFunction();
193   if (n_inputs > 0) {
194     const DataType* input_types = &kernel->input_dtypes()[0];
195     const absl::InlinedVector<TensorHandle*, 4>* handles;
196     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles));
197     for (int i = 0; i < n_inputs; ++i) {
198       TensorHandle* handle = (*handles)[i];
199       Device* expected_device = kernel->InputDevice(i);
200       if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) {
201         // Extract a handle on the op device from a packed input.
202         // This happens when a function is marked for XLA compilation.
203         // MaybePackInputTensor guarantees that a primitive op has no packed
204         // input at this point.
205         for (int j = 0; j < handle->NumPackedHandles(); ++j) {
206           TensorHandle* h = nullptr;
207           TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h));
208           if ((h->op_device() != nullptr) &&
209               (h->op_device()->name() == op->DeviceName())) {
210             op->UpdateInput(i, h);
211             handle = h;
212             break;
213           }
214         }
215       }
216       Device* handle_device = handle->DeviceOrHostCPU(*ctx);
217       const bool maybe_copy =
218           !is_function || handle->Type() != TensorHandle::REMOTE;
219       // If the input is already on the right device, then nothing to do.
220       if (expected_device != handle_device && maybe_copy) {
221         TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
222                                                      handle, i, handle_device,
223                                                      expected_device, &handle));
224         op->UpdateInput(i, handle);
225         // Unref handle since it has a ref as an input now
226         handle->Unref();
227       }
228       if (handle->dtype != input_types[i]) {
229         return errors::InvalidArgument(
230             "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
231             " was expected to be a ", DataTypeString(input_types[i]),
232             " tensor but is a ", DataTypeString(handle->dtype), " tensor");
233       }
234     }
235   }
236   return Status::OK();
237 }
238 
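// Computes the output dtypes of `op` from its NodeDef and the matching OpDef
// (or the function signature when `op` is a function).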
239 Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
240   const auto& node_def = op->MutableAttrs()->BuildNodeDef();
241   const OpDef* op_def = nullptr;
242 
243   const FunctionDef* function_def =
244       op->EagerContext().FuncLibDef()->Find(op->Name());
245   if (function_def != nullptr) {
246     op_def = &(function_def->signature());
247   } else {
248     TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
249   }
250 
251   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
252 
253   return Status::OK();
254 }
255 
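// Helpers that combine 128-bit fingerprints (with another fingerprint or an
// int64); used below when building the kernel cache key.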
256 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
257                                                const tensorflow::Fprint128& b) {
258   return {tensorflow::FingerprintCat64(a.low64, b.low64),
259           tensorflow::FingerprintCat64(a.high64, b.high64)};
260 }
261 
262 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
263                                                const int64 b) {
264   auto x = tensorflow::FingerprintCat64(a.low64, b);
265   return {x, tensorflow::FingerprintCat64(a.high64, x)};
266 }
267 
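// Selects the device to record for a function input: non-local handles keep
// their device (or host CPU), DT_RESOURCE inputs use the device stored in the
// resource handle, and other inputs fall back to host CPU when their dtype is
// kept in host memory.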
268 Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
269                          Device** result) {
270   Device* cpu_device = ctx.HostCPU();
271   string device_name;
272   if (tensor_handle->Type() != TensorHandle::LOCAL) {
273     Device* device = tensor_handle->device();
274     device_name = device != nullptr ? device->name() : cpu_device->name();
275     *result = (device == nullptr ? cpu_device : device);
276   } else if (tensor_handle->dtype == DT_RESOURCE) {
277     // Use the resource's actual device because it is the device that will
278     // influence partitioning the multi-device function.
279     const Tensor* tensor;
280     // TODO(fishx): Avoid blocking here.
281     TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
282     const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
283     device_name = handle.device();
284 
285     Device* input_device;
286     TF_RETURN_IF_ERROR(
287         ctx.FindDeviceFromName(device_name.c_str(), &input_device));
288     *result = input_device;
289   } else {
290     Device* device = tensor_handle->device();
291     const bool is_tpu = device != nullptr && device->device_type() == "TPU";
292     // int32 return values can be placed on TPUs.
293     const bool use_host_memory =
294         is_tpu ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype)
295                : MTypeFromDType(tensor_handle->dtype);
296     if (use_host_memory) {
297       *result = cpu_device;
298     } else {
299       device_name = device != nullptr ? device->name() : cpu_device->name();
300       *result = (device == nullptr ? cpu_device : device);
301     }
302   }
303   return Status::OK();
304 }
305 
306 // Appends a TensorShape object to Fprint128 hash.
307 // For best performance, we would like to avoid dynamic memory allocation in
308 // this function.
309 // If "shape" has unknown rank, we attach "?" to hashed content; otherwise we
310 // attach every dim size to hashed content.
311 void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
312                                     Fprint128* fingerprint) {
313   if (shape.unknown_rank()) {
314     char c = '?';
315     *fingerprint = FingerprintCat128(*fingerprint, c);
316   } else {
317     for (int i = 0; i < shape.dims(); i++) {
318       int64 dim = shape.dim_size(i);
319       *fingerprint = FingerprintCat128(*fingerprint, dim);
320     }
321   }
322 }
323 
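// Looks up the boolean attribute `attr_name`, first on the op itself and then
// on the corresponding function definition.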
324 Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx,
325                    const char* attr_name, bool* value) {
326   Status status = op->Attrs().Get(attr_name, value);
327   if (status.ok()) {
328     DVLOG(2) << "Caller explicitly specifies " << attr_name
329              << (*value ? "=true " : "=false, ") << op->DebugString();
330     return Status::OK();
331   }
332 
333   const FunctionDef* function_def =
334       ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name());
335   if (function_def == nullptr) {
336     return errors::NotFound("Failed to find function '", op->Name(), "'");
337   }
338 
339   status = GetNodeAttr(AttrSlice(&function_def->attr()), attr_name, value);
340   if (status.ok()) {
341     DVLOG(2) << "Function definition explicitly specifies " << attr_name
342              << (*value ? "=true" : "=false");
343     return Status::OK();
344   }
345   return status;
346 }
347 
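// Decides whether `op` should be compiled with XLA: primitive ops and
// components of multi-device functions are not; otherwise an explicit
// kXlaMustCompileAttr wins, and functions placed on TPU/XLA_GPU/XLA_CPU
// devices are compiled by default.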
348 Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
349                           bool* compile_with_xla) {
350   if (!op->is_function()) {
351     *compile_with_xla = false;
352     return Status::OK();
353   }
354 
355   if (op->remote_func_params().has_value() &&
356       op->remote_func_params().value().step_id.has_value()) {
357     // If the op is a component of a multi-device function, don't compile it
358     // with XLA.
359     *compile_with_xla = false;
360     return Status::OK();
361   }
362 
363   Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla);
364   if (status.ok()) {
365     return Status::OK();
366   }
367 
368   // No explicit requests. Compile for XLA devices by default.
369   if (op->GetDeviceParsedName().type == "TPU" ||
370       op->GetDeviceParsedName().type == "XLA_GPU" ||
371       op->GetDeviceParsedName().type == "XLA_CPU") {
372     DVLOG(2) << "Compiling " << op->Name()
373              << " with XLA because it is running on an XLA device "
374              << op->GetDeviceParsedName().type;
375     *compile_with_xla = true;
376   } else {
377     *compile_with_xla = false;
378   }
379 
380   return Status::OK();
381 }
382 
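// Builds a kernel cache key from the op's attributes, the placement policy,
// and (for functions) the input devices and resource dtypes/shapes, then
// returns the cached KernelAndDevice or creates one, selecting a device if
// none was chosen yet.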
383 Status GetOrCreateKernelAndDevice(
384     EagerOperation* op, TensorHandle** retvals, int* num_retvals,
385     core::RefCountPtr<KernelAndDevice>* out_kernel) {
386   EagerContext& ctx = op->EagerContext();
387   Device* device = absl::get<Device*>(op->Device());
388 
389   Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
390   // Include soft placement policy in cache key since the placement strategy
391   // can change and thus affect which kernel is picked.
392   cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
393   // The launch-time rendezvous reuse setting is bundled with the kernel, so we
394   // need to include it in the cache key.
395   cache_key =
396       FingerprintCat128(cache_key, ctx.GetReuseRendezvousForFunctions());
397 
398   std::vector<Device*> input_dev_ptrs;
399   absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
400   std::unordered_map<int, DtypeAndPartialTensorShape>
401       input_resource_variable_dtypes_and_shapes;
402   // We can eliminate some overhead by running simple functions using regular
403   // CallOp kernel. However, it is tricky to figure out which functions should
404   // be run using CallOp. Also, currently CallOp runs neither optimization
405   // passes (needed for TPU/XLA) nor grappler.
406   // Here are some cases where a function should be run in multi-device mode:
407   //  - Function takes at least two resources on different devices.
408   //  - Function takes a resource on deviceA and a body op explicitly placed
409   //  on deviceB.
410   //  - Function has a colocation constraint.
411   //  - Function has an explicit device annotation (which might not be using
412   //    full canonical device name) different from op_device. Note that false
413   //    positives are ok.
414   //  - Function has a node or a (node) attribute that can potentially make
415   //    the function multi-device after a rewrite pass (e.g. various XLA/TPU
416   //    special nodes and attributes)
417   if (op->is_function()) {
418     profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
419                                profiler::TraceMeLevel::kInfo);
420     input_dev_ptrs.reserve(op->Inputs().size());
421     const absl::InlinedVector<TensorHandle*, 4>* inputs;
422     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
423     for (int i = 0, end = inputs->size(); i < end; i++) {
424       TensorHandle* input = (*inputs)[i];
425 
426       // Get device for this input, and add it to 'cache_key'.
427       Device* input_device;
428       TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
429       input_dev_ptrs.push_back(input_device);
430       CompositeDevice* composite_device = nullptr;
431       if (ctx.FindCompositeDeviceFromName(input_device->name(),
432                                           &composite_device)
433               .ok()) {
434         composite_devices[input_device->name()] =
435             composite_device->underlying_devices();
436       }
437       cache_key =
438           FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
439 
440       // If input is a ResourceHandle, get its resource handle dtypes and shapes
441       // and add them to 'cache_key'.
442       if (input->dtype == DT_RESOURCE) {
443         // We only care about data type and shape for resource variable inputs.
444         // But we have no way to tell if input is resource variable (other than
445         // looking it up in ResourceMgr, which is slow). So we just get
446         // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
447         // resource_dtypes_and_shapes is not empty, take the first element.
448         std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
449         TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
450             &resource_dtypes_and_shapes));
451         if (!resource_dtypes_and_shapes.empty()) {
452           const DtypeAndPartialTensorShape& dtype_and_shape =
453               resource_dtypes_and_shapes.at(0);
454           input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
455 
456           // Add _Arg index, dtype and shape to "cache_key".
457           cache_key = FingerprintCat128(cache_key, i);
458           DataType dtype = dtype_and_shape.dtype;
459           cache_key = FingerprintCat128(cache_key, dtype);
460           AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key);
461         }
462       }
463     }
464   }
465 
466   core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
467   if (kernel == nullptr) {
468     DVLOG(2) << "Creating new kernel for " << op->Name() << " on device "
469              << DeviceNameOrUnspecified(absl::get<Device*>(op->Device()));
470     bool run_function_with_flr = false;
471     bool function_outputs_on_op_device = false;
472     if (op->is_function()) {
473       bool compile_with_xla;
474       TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
475       if (compile_with_xla) {
476         // Note that it is not ideal, but currently correct, to set this
477         // attribute after computing the kernel cache key above.
478         // Note: If the attribute is already set to true, this is a noop.
479         op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
480       } else {
481         run_function_with_flr = true;
482       }
483       GetFuncAttr(op, ctx, kOutputsOnOpDevice, &function_outputs_on_op_device)
484           .IgnoreError();
485     }
486 
487     const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
488     if (device == nullptr) {
489       TF_RETURN_IF_ERROR(
490           ctx.SelectDevice(op->GetDeviceParsedName(), ndef, &device));
491 
492       DVLOG(1) << "Placer place op [" << op->Name()
493                << "] on device: " << device->name();
494       DVLOG(4) << "Available kernels for " << op->Name() << " are "
495                << KernelsRegisteredForOp(op->Name());
496       op->SetDevice(device);
497     }
498 
499     FunctionLibraryRuntime* flr =
500         device == nullptr ? nullptr : ctx.func_lib(device);
501     if (device != nullptr && flr == nullptr) {
502       return errors::Unavailable(
503           "Unable to find a FunctionLibraryRuntime corresponding to device ",
504           device->name());
505     }
506     auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
507                                                                : ctx.runner();
508     GraphCollector* graph_collector = nullptr;
509     if (ctx.ShouldStoreGraphs()) {
510       graph_collector = ctx.GetGraphCollector();
511     }
512     // Treat the function as multi_device only when we are not compiling
513     // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
514     // will create an XlaLaunchOp kernel to compile and run the function.
515     if (run_function_with_flr) {
516       // Multi-device functions don't use the rendezvous from eager context.
517       // If we use that rendezvous, multiple concurrent calls to the same
518       // function will likely result in collisions. However, this also means
519       // that we don't support legitimate sending/receiving across function
520       // boundary.
521       DVLOG(2) << "Running " << ndef.op() << " using multi-device function. "
522                << "Full node_def=" << ndef.DebugString();
523       std::function<int64()> get_op_id = nullptr;
524 #if !defined(IS_MOBILE_PLATFORM)
525       get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
526 #endif  // IS_MOBILE_PLATFORM
527       kernel.reset(new KernelAndDeviceFunc(
528           flr, ctx.pflr(), std::move(input_dev_ptrs),
529           std::move(composite_devices),
530           std::move(input_resource_variable_dtypes_and_shapes), runner,
531           ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
532           function_outputs_on_op_device, ctx.RendezvousCreator(), get_op_id));
533     } else {
534       DVLOG(2) << "Running " << ndef.op() << " using op kernel. "
535                << "Full node_def=" << ndef.DebugString();
536       kernel.reset(new KernelAndDeviceOp(
537           ctx.GetRendezvous(), ctx.LogMemory(), flr, runner,
538           ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
539     }
540 
541     TF_RETURN_IF_ERROR(
542         kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector));
543 
544     if (op->is_function()) {
545       ctx.AddKernelToCache(cache_key, kernel.get());
546     } else {
547       // Exclude tf.data op kernels from being cached. The reason for this is
548       // that tf.data op kernels that accept a user-defined function will have a
549       // unique cache key every time they are executed (because the user-defined
550       // function is traced every time). Caching such kernels provides no
551       // benefit and in some cases results in linear memory growth of user
552       // programs that build input pipeline graphs in a loop.
553       const OpDef* op_def;
554       TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def));
555       if (KernelCacheEnabled(*op_def)) {
556         ctx.AddKernelToCache(cache_key, kernel.get());
557       }
558     }
559   }
560 
561   int num_outputs = kernel->num_outputs();
562   if (num_outputs > *num_retvals) {
563     return errors::InvalidArgument("Expecting ", num_outputs,
564                                    " outputs, but *num_retvals is ",
565                                    *num_retvals);
566   }
567   *num_retvals = num_outputs;
568 
569   kernel->Ref();  // Ownership of reference is passed to out_kernel.
570   out_kernel->reset(kernel.get());
571   return Status::OK();
572 }
573 
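// Creates a remote TensorHandle for output `output_num` of a cross-process
// op: unshaped on the master, lazy on other workers. Unsupported on mobile.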
574 Status CreateUnshapedOutput(
575     const KernelAndDevice& kernel, const int output_num, Device* output_device,
576     const DataType& output_dtype,
577     const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
578     EagerContext* ctx, TensorHandle** output) {
579 #if defined(IS_MOBILE_PLATFORM)
580   return errors::Unimplemented(
581       "Remote outputs are not available on mobile devices.");
582 #else  // !IS_MOBILE_PLATFORM
583   int64 op_id;
584   if (remote_func_params.has_value()) {
585     op_id = remote_func_params.value().op_id;
586   } else {
587     return errors::InvalidArgument(
588         "Unable to find a remote op id for a remote output of ", kernel.name());
589   }
590   string remote_task;
591   if (!DeviceNameUtils::GetTaskName(output_device->parsed_name(),
592                                     &remote_task)) {
593     return errors::InvalidArgument(
594         "Unable to find remote task corresponding to device ",
595         output_device->name());
596   }
597   if (ctx->RemoteMgr()->IsMaster()) {
598     *output = TensorHandle::CreateUnshapedRemoteHandle(
599         op_id, output_num, remote_task, output_dtype, output_device, ctx);
600   } else {
601     *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num,
602                                                    output_dtype, output_device,
603                                                    /*is_ready=*/false, ctx);
604   }
605   return Status::OK();
606 #endif  // !IS_MOBILE_PLATFORM
607 }
608 
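// Allocates the op's output handles and either enqueues an AsyncExecuteNode on
// the executor (async mode) or runs an ExecuteNode synchronously.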
609 Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
610                         EagerOperation* op, TensorHandle** retvals) {
611   EagerExecutor& executor = op->Executor();
612   EagerContext& ctx = op->EagerContext();
613   GraphCollector* graph_collector = nullptr;
614   if (ctx.ShouldStoreGraphs()) {
615     graph_collector = ctx.GetGraphCollector();
616   }
617   const int num_outputs = kernel->num_outputs();
618   absl::optional<EagerRemoteFunctionParams> remote_func_params =
619       op->remote_func_params();
620   if (kernel->IsCrossProcess() && !remote_func_params.has_value()) {
621     // Create an eager op id for a cross-process function if one does not exist.
622 #if defined(IS_MOBILE_PLATFORM)
623     return errors::Unimplemented(
624         "Cross-process functions are not supported on mobile devices.");
625 #else  // !IS_MOBILE_PLATFORM
626     const int64 op_id = ctx.RemoteMgr()->NextOpId();
627     remote_func_params =
628         EagerRemoteFunctionParams{op_id, /*step_id=*/absl::nullopt};
629 #endif  // !IS_MOBILE_PLATFORM
630   }
631   if (executor.Async()) {
632     const DataTypeVector& output_dtypes = kernel->output_dtypes();
633     for (int i = 0, end = num_outputs; i < end; ++i) {
634       Device* output_device = ctx.CanonicalDevice(kernel->OutputDevice(i));
635       if (output_device == nullptr || output_device->IsLocal()) {
636         retvals[i] = TensorHandle::CreateEmptyLocalHandle(
637             /* d= */ output_device, /* op_device= */ kernel->device(),
638             /* resource_device= */ kernel->OutputResourceDevice(i),
639             output_dtypes[i], &ctx);
640       } else {
641         TF_RETURN_IF_ERROR(
642             CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
643                                  remote_func_params, &ctx, &retvals[i]));
644       }
645     }
646     const absl::InlinedVector<TensorHandle*, 4>* inputs;
647     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
648     auto node = absl::make_unique<AsyncExecuteNode>(
649         &ctx, *inputs, remote_func_params, std::move(kernel), graph_collector,
650         op->GetCancellationManager(),
651         absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
652     // Release the inputs from the eager operation since the AsyncExecuteNode
653     // would have taken ownership. This allows the inputs to be forwarded if
654     // possible.
655     op->Clear();
656     // For async mode, execution order will make sure that all
657     // input handles are ready before executing them.
658     // TODO(b/137118203): Consider executing "cheap" kernels inline for
659     // performance.
660     return executor.AddOrExecute(std::move(node));
661   } else {
662     for (int i = 0, end = num_outputs; i < end; ++i) {
663       retvals[i] = nullptr;
664     }
665     const absl::InlinedVector<TensorHandle*, 4>* inputs;
666     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
667     ExecuteNode node(&ctx, *inputs, remote_func_params, kernel, graph_collector,
668                      op->GetCancellationManager(),
669                      {retvals, static_cast<size_t>(num_outputs)},
670                      op->GetStackTrace());
671     Status s = executor.SyncExecute(&node);
672     // We release the inputs AFTER executing the operation in sync mode since
673     // ExecuteNode does not increment the reference count and thus does not have
674     // ownership of the inputs while executing.
675     op->Clear();
676     return s;
677   }
678 }
679 
680 // There are a lot of references to devices in this function and around.
681 // Here is what they mean:
682 //  EagerOperation::Device(): The device on which the user requested the op
683 //    be executed, except if we had to change the device due to resource inputs
684 //    or CPU pinning. If the user did not request a device, the op does not
685 //    take resources, and we did not pin it to CPU, the device can be nullptr.
686 //  KernelAndDevice::Device(): The first time we see an op (combined with
687 //    its attributes), we need to create a KernelAndDevice object for it.
688 //    If op->Device() is a nullptr, we select a device for the op when
689 //    creating the KernelAndDevice. A concrete device will always be selected
690 //    here except when `op` is a function to be executed using function library
691 //    runtime. In this case, we don't select a device because running
692 //    a function with explicitly requested device has different behavior than
693 //    running without an explicitly requested device.
694 Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
695                          int* num_retvals) {
696   ScopedMemoryDebugAnnotation op_annotation(
697       op->op_name(), op->remote_func_params().has_value()
698                          ? op->remote_func_params().value().step_id.value_or(0)
699                          : 0);
700   profiler::TraceMe activity(
701       [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
702       profiler::TraceMeLevel::kInfo);
703   EagerContext& ctx = op->EagerContext();
704   auto& executor = op->Executor();
705   TF_RETURN_IF_ERROR(executor.status());
706 
707   core::RefCountPtr<KernelAndDevice> kernel;
708   auto status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
709 
710   // Run all the registered rewrite passes after placement, regardless of
711   // whether the placement was successful. The passes can either create new ops
712   // (without placement) or update some fields of the input op.
713   std::unique_ptr<tensorflow::EagerOperation> out_op;
714   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
715       EagerOpRewriteRegistry::POST_PLACEMENT, op, &out_op));
716   if (out_op) {
717     op = out_op.get();
718     // If the out op doesn't have a device, either because it is a new op or
719     // the op wasn't placed successfully, then we do the placement again.
720     if (op->Device() == kVariantDeviceNull) {
721       status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
722     }
723   }
724   if (!status.ok()) return status;
725 
726   int num_outputs = kernel->num_outputs();
727   TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
728 
729   if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
730     string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
731                                  kernel->device()->name());
732     if (!logging::LogToListeners(msg)) {
733       LOG(INFO) << msg;
734     }
735   }
736 
737   Status s = AddOrExecuteNode(std::move(kernel), op, retvals);
738   // Since the operation failed, we need to Unref any outputs if they were
739   // allocated.
740   if (!s.ok()) {
741     for (int i = 0, end = num_outputs; i < end; ++i) {
742       if (retvals[i] != nullptr) {
743         retvals[i]->Unref();
744       }
745     }
746   }
747 
748   return s;
749 }
750 
751 // Run a Pack op to pack the tensors pointed to by a packed input TensorHandle if
752 // the op is a primitive op.
753 Status MaybePackInputTensor(EagerOperation* op) {
754   if (op->is_function()) {
755     // Functions could take packed TensorHandles as inputs.
756     return Status::OK();
757   }
758   EagerContext& ctx = op->EagerContext();
759   const absl::InlinedVector<TensorHandle*, 4>* inputs;
760   TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
761   for (int i = 0; i < inputs->size(); ++i) {
762     TensorHandle* handle = (*inputs)[i];
763     if (handle->Type() == TensorHandle::PACKED) {
764       EagerOperation pack_op(&ctx);
765       TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr,
766                                        /*remote=*/false, /*executor=*/nullptr));
767       pack_op.MutableAttrs()->Set("N", handle->NumPackedHandles());
768       pack_op.MutableAttrs()->Set("T", handle->dtype);
769       for (int i = 0; i < handle->NumPackedHandles(); ++i) {
770         tensorflow::TensorHandle* h = nullptr;
771         TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(i, &h));
772         TF_RETURN_IF_ERROR(pack_op.AddInput(h));
773       }
774       int num_retvals = 1;
775       absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
776       TF_RETURN_IF_ERROR(
777           EagerLocalExecute(&pack_op, retvals.data(), &num_retvals));
778       tensorflow::TensorHandle* ret = retvals.at(0);
779       op->UpdateInput(i, ret);
780       ret->Unref();
781     }
782   }
783   return Status::OK();
784 }
785 
786 #if !defined(IS_MOBILE_PLATFORM)
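// Fills the wire-format eager::Operation from `op`: op id, name, non-default
// attributes, device, and whether it is a function.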
787 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
788   EagerContext& ctx = op->EagerContext();
789 
790   remote_op->set_id(ctx.RemoteMgr()->NextOpId());
791   remote_op->set_name(op->Name());
792 
793   op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
794   remote_op->set_device(absl::get<Device*>(op->Device())->name());
795   remote_op->set_is_function(op->is_function());
796 }
797 
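// For "VarHandleOp", records the variable's dtype and shape (taken from the
// op's attributes) on the returned resource handle so that remote function
// instantiation can use them.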
798 Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
799                                     const DataTypeVector& output_dtypes,
800                                     TensorHandle** retvals) {
801   if (remote_op.name() == "VarHandleOp") {
802     if (output_dtypes.size() != 1) {
803       return errors::Internal("VarHandleOp should only have one output.");
804     }
805     if (output_dtypes[0] != DT_RESOURCE) {
806       return errors::Internal(
807           "The output of VarHandleOp should be a DT_RESOURCE.");
808     }
809     AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
810     const AttrValue* dtype;
811     TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
812     const AttrValue* shape;
813     TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
814     retvals[0]->SetResourceHandleDtypeAndShape(
815         {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
816   }
817   return Status::OK();
818 }
819 
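// Executes `op` on a remote task: serializes the op and its inputs into an
// EnqueueRequest (copying local inputs to the remote host CPU when needed),
// creates unshaped remote handles for the outputs, and enqueues a
// RemoteExecuteNode on the op's executor.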
820 Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
821                           int* num_retvals) {
822   EagerContext& ctx = op->EagerContext();
823 
824   // TODO(fishx): Remove following code when lazy tensor copy is ready.
825   if (op->Device() == kVariantDeviceNull) {
826     tensorflow::Device* device = nullptr;
827     string device_name = op->DeviceName();
828     TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
829     op->SetDevice(device);
830   }
831 
832   core::RefCountPtr<eager::EagerClient> eager_client;
833   uint64 context_id = ctx.GetContextId();
834   TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client));
835   string remote_task;
836   if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
837     return errors::InvalidArgument(
838         "Unable to find remote task corresponding to device ",
839         op->DeviceName());
840   }
841 
842   std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
843   request->set_context_id(context_id);
844 
845   eager::Operation* remote_op = request->add_queue()->mutable_operation();
846 
847   tensorflow::Device* op_device = absl::get<Device*>(op->Device());
848   {
849     profiler::TraceMe activity("CopyInputToExpectedDevice",
850                                profiler::TraceMeLevel::kInfo);
851     const bool is_function = op->is_function();
852     const absl::InlinedVector<TensorHandle*, 4>* inputs;
853     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
854     for (int i = 0, end = inputs->size(); i < end; i++) {
855       tensorflow::TensorHandle* input = (*inputs)[i];
856       tensorflow::Device* input_device = input->device();
857       tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx);
858       const string* input_device_name = &input_device_or_cpu->name();
859       bool serialize_resource_dtype_and_shape = false;
860       if (op_device != input_device &&
861           // If the expected and actual devices are on the same task, don't
862           // explicitly copy, and instead depend on the copy to happen locally
863           // when the op is executed on the device.
864           !ctx.OnSameTask(op_device, input_device)) {
865         if (!is_function || input_device_or_cpu->IsLocal()) {
866           tensorflow::Device* remote_cpu_device;
867           TF_RETURN_IF_ERROR(
868               ctx.CPUDeviceOnTask(op_device, &remote_cpu_device));
869           // TODO(b/110044833): It's possible the same tensor gets copied to the
870           // remote device repeatedly.
871           // Always copy to the remote CPU so that the actual device can be
872           // correctly determined after the kernel is selected/instantiated,
873           // since the op might have its inputs on host memory.
874           TensorHandle* handle = input;
875           Device* handle_device = handle->DeviceOrHostCPU(ctx);
876           // If the input is already on the right device, then nothing to do.
877           if (remote_cpu_device != handle_device) {
878             TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
879                 &ctx, op, op_device, handle, i, handle_device,
880                 remote_cpu_device, &handle));
881             op->UpdateInput(i, handle);
882             input = handle;
883             input_device = remote_cpu_device;
884             input_device_name = &remote_cpu_device->name();
885             // Unref handle since it has a ref as an input now
886             handle->Unref();
887           }
888         } else {
889           serialize_resource_dtype_and_shape =
890               (input->dtype == DT_RESOURCE) &&
891               (!input->HasResourceShapeMirror(op_device,
892                                               ctx.GetContextViewId()));
893         }
894       }
895       auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
896       // For a remote component function, a function execution request and an
897       // input generation request may come from different workers. We need to
898       // guarantee that the input generation request is processed before the
899       // function execution request, so wait until the remote input is ready
900       // before sending it to the multi-device function device.
901       const bool wait_until_ready = op->is_function();
902       TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
903           input, wait_until_ready, input_handle, input_device,
904           *input_device_name, serialize_resource_dtype_and_shape));
905       if (!input_handle->resource_dtypes_and_shapes().empty()) {
906         TF_RETURN_IF_ERROR(
907             input->AddResourceShapeMirror(op_device, input_handle->op_id(),
908                                           input_handle->output_num(), &ctx));
909       }
910     }
911   }
912 
913   PrepareRemoteOp(remote_op, op);
914 
915   DataTypeVector output_dtypes;
916   TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
917 
918   const size_t num_outputs = output_dtypes.size();
919   if (num_outputs != *num_retvals) {
920     return errors::InvalidArgument(
921         "num_retvals does not match expected output dtypes");
922   }
923   *num_retvals = num_outputs;
924 
925   const tensorflow::uint64 id = remote_op->id();
926   for (size_t i = 0; i < num_outputs; ++i) {
927     // TODO(nareshmodi): Change the callback to instead add the decref to a
928     // list of pending decrefs that we can send as a batch with the next
929     // execute.
930 
931     // The device_ and resource_device_ of this TensorHandle might be
932     // incorrect. For multi-device functions, we don't know the output device
933     // until the function is instantiated on a remote worker. Luckily, we don't
934     // need to know the correct remote device here. We just need to know that it
935     // is remote. If we need to copy this tensor to this process or run any ops
936     // which take this tensor as an input, block until the correct device is
937     // set.
938     const bool unknown_device = op->is_function();
939     retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
940         id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
941   }
942 
943   // Store the data type and shape of a remote resource variable on the
944   // corresponding remote TensorHandle (output of 'VarHandleOp').
945   // If the variable is an input of a remote function, the function may need
946   // the type and shape during function instantiation. Store the type and
947   // shape on the eager master and send them to the default function device along
948   // with the EnqueueRequest.
949   TF_RETURN_IF_ERROR(
950       StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
951 
952   auto& executor = op->Executor();
953   DVLOG(4) << "Execute remote eager op: " << op->Name()
954            << " (is async?: " << executor.Async() << ").";
955 
956   const absl::InlinedVector<TensorHandle*, 4>* inputs;
957   TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
958 
959   std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
960       &op->EagerContext(), std::move(request), op_device,
961       ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(),
962       op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
963       *inputs, {retvals, num_outputs}));
964 
965   if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
966     string msg = strings::StrCat(
967         "Executing op ", op->Name(), " on task ",
968         DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
969     if (!logging::LogToListeners(msg)) {
970       LOG(INFO) << msg;
971     }
972   }
973 
974   Status s = executor.AddOrExecute(std::move(node));
975   // Since the operation failed, we need to Unref any outputs that were
976   // allocated.
977   if (!s.ok()) {
978     for (size_t i = 0; i < num_outputs; ++i) {
979       retvals[i]->Unref();
980     }
981   }
982 
983   return s;
984 }
985 #endif  // IS_MOBILE_PLATFORM
986 
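// Converts the kernel's outputs into TensorHandles: creates local or remote
// handles when retvals[i] is null, otherwise fills the pre-created handle with
// the produced tensor or remote shape.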
987 Status GetKernelOutputs(
988     std::vector<EagerKernelRet>* outputs, int num_outputs,
989     TensorHandle** retvals, EagerContext* ctx, KernelAndDevice* kernel,
990     const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
991   for (int i = 0, end = num_outputs; i < end; ++i) {
992     if (retvals[i] == nullptr) {
993       EagerKernelRet& ret = (*outputs)[i];
994       Device* output_device = ctx->CanonicalDevice(kernel->OutputDevice(i));
995       if (ret.index() == 0) {
996         retvals[i] = TensorHandle::CreateLocalHandle(
997             std::move(absl::get<Tensor>(ret)),
998             /* d= */ output_device,
999             /* op_device= */ kernel->device(),
1000             /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
1001       } else {
1002         const DataTypeVector& output_dtypes = kernel->output_dtypes();
1003         TF_RETURN_IF_ERROR(
1004             CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1005                                  remote_func_params, ctx, &retvals[i]));
1006 #if !defined(IS_MOBILE_PLATFORM)
1007         TF_RETURN_IF_ERROR(
1008             retvals[i]->SetRemoteShape(absl::get<TensorShape>(ret),
1009                                        output_device, ctx->GetContextViewId()));
1010 #endif  // IS_MOBILE_PLATFORM
1011       }
1012     } else {
1013       if (!kernel->IsFunction() &&
1014           TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) {
1015         return errors::Internal(
1016             "Kernel output tensor handle has a different op device than the "
1017             "kernel. This should never happen.");
1018       }
1019       if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
1020                            retvals[i]->device())) {
1021         return errors::Internal(
1022             "Kernel output tensor handle is located on a different device than "
1023             "the specified kernel output device. This should never happen.");
1024       }
1025 
1026       EagerKernelRet& ret = (*outputs)[i];
1027       if (ret.index() == 0) {
1028         TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
1029             std::move(absl::get<Tensor>(ret)),
1030             ctx->CanonicalDevice(kernel->OutputDevice(i))));
1031       } else {
1032 #if defined(IS_MOBILE_PLATFORM)
1033         return errors::Unimplemented(
1034             "Remote outputs are not available on mobile devices.");
1035 #else  // !IS_MOBILE_PLATFORM
1036         TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape(
1037             absl::get<TensorShape>(ret), retvals[i]->device(),
1038             ctx->GetContextViewId()));
1039 #endif  // !IS_MOBILE_PLATFORM
1040       }
1041     }
1042   }
1043   return Status::OK();
1044 }
1045 
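// Copies graphs recorded by the context's GraphCollector into RunMetadata
// (partition graphs plus pre-/post-optimization function graphs), then clears
// the collector.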
1046 void CollectGraphs(EagerContext* ctx) {
1047   mutex_lock ml(*ctx->MetadataMu());
1048 
1049   GraphCollector* collector = ctx->GetGraphCollector();
1050   mutex_lock mll(collector->mu);
1051 
1052   // Adding to partition graphs for backward compatibility.
1053   for (const auto& graph : collector->partitioned_graphs) {
1054     *ctx->RunMetadataProto()->add_partition_graphs() = graph;
1055   }
1056 
1057   if (collector->dirty) {
1058     auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
1059     *function_graphs->mutable_post_optimization_graph() =
1060         collector->optimized_graph;
1061     *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph;
1062     for (const auto& graph : collector->partitioned_graphs) {
1063       *function_graphs->add_partition_graphs() = graph;
1064     }
1065   }
1066 
1067   collector->ClearGraphs();
1068 }
1069 }  // namespace
1070 
1071 Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
1072                     int* num_retvals) {
1073   profiler::TraceMe activity(
1074       [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
1075       profiler::TraceMeLevel::kInfo);
1076 
1077   if (!op->Executor().Async()) {
1078     // In sync mode, always clear error to maintain the same behavior as before.
1079     // TODO(b/141004939): Remove this.
1080     op->Executor().ClearError();
1081   }
1082 
1083   std::unique_ptr<tensorflow::EagerOperation> out_op;
1084   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1085       EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
1086 
1087   if (op->IsLocal()) {
1088     if (out_op) {
1089       op = out_op.get();
1090     }
1091     TF_RETURN_IF_ERROR(MaybePackInputTensor(op));
1092     return EagerLocalExecute(op, retvals, num_retvals);
1093   }
1094 
1095 #if defined(IS_MOBILE_PLATFORM)
1096   return errors::Unimplemented(
1097       "Eager's remote execution is not available on mobile devices.");
1098 #else   // !IS_MOBILE_PLATFORM
1099   if (out_op) {
1100     op = out_op.get();
1101   }
1102   return EagerRemoteExecute(op, retvals, num_retvals);
1103 #endif  // !IS_MOBILE_PLATFORM
1104 }
1105 
1106 // TODO(gjn): Consider moving into ExecuteNode class
1107 Status EagerKernelExecute(
1108     EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1109     const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1110     const core::RefCountPtr<KernelAndDevice>& kernel,
1111     GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1112     absl::Span<TensorHandle*> retvals,
1113     const absl::optional<ManagedStackTrace>& stack_trace) {
1114   profiler::TraceMe activity("EagerKernelExecute",
1115                              profiler::TraceMeLevel::kInfo);
1116   std::vector<EagerKernelRet> outputs(1);
1117 
1118   ExecuteNodeArgs inputs(op_inputs.size());
1119   TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs, kernel));
1120   // TODO(apassos) figure out how to record stats for ops which are a part of
1121   // functions.
1122   // TODO(b/111859745): When we support recovering from kernel/device errors, we
1123   // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA
1124   // device. We don't call it now because it is an unneeded overhead (it
1125   // acquires a lock) and we can't recover from errors anyway.
1126   ScopedStepContainer* container = ctx->StepContainer();
1127   TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
1128                                  cancellation_manager, remote_func_params,
1129                                  stack_trace));
1130   if (graph_collector != nullptr) {
1131     CollectGraphs(ctx);
1132   }
1133 
1134   if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) {
1135     return errors::Internal(
1136         "EagerKernelExecute returns a list of ", outputs.size(),
1137         " tensors but ", retvals.size(),
1138         " is expected. This should never "
1139         "happen. Please file a bug with the TensorFlow team.");
1140   }
1141   return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx,
1142                           kernel.get(), remote_func_params);
1143 }
1144 
1145 namespace {
1146 
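// Copies `h` to `dstd` within this process. With `mirror` set, reuses or
// creates a local mirror on `h`; otherwise creates an empty local handle that
// a CopyToDeviceNode fills in.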
1147 Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1148                               EagerExecutor* executor, Device* dstd,
1149                               bool mirror, TensorHandle** result) {
1150   TF_RETURN_IF_ERROR(executor->status());
1151   Device* d = ctx->CanonicalDevice(dstd);
1152   if (mirror && h->HasLocalMirror(d)) {
1153     h->Ref();
1154     *result = h;
1155     return Status::OK();
1156   }
1157 
1158   bool async = executor->Async();
1159   if (mirror) {
1160     h->Ref();
1161     *result = h;
1162 
1163     if (h->HasLocalMirror(d)) {
1164       return Status::OK();
1165     }
1166 
1167     // We don't bother adding an empty local mirror in sync mode since we'll be
1168     // executing the operation directly and calling AddLocalMirror ourselves.
1169     // A reference count is still needed; it will be removed if the operation
1170     // fails.
1171     if (async) {
1172       Status s = h->AddEmptyLocalMirror(d);
1173       if (!s.ok()) {
1174         // If a mirror was added after our HasLocalMirror check, just return:
1175         // another thread has already added the mirror.
1176         if (s.code() == error::Code::ALREADY_EXISTS) {
1177           return Status::OK();
1178         }
1179 
1180         // Remove the previously added reference count since adding the mirror
1181         // failed.
1182         h->Unref();
1183         *result = nullptr;
1184         return s;
1185       }
1186     }
1187   } else {
1188     *result = TensorHandle::CreateEmptyLocalHandle(
1189         d, dstd, h->resource_device(), h->dtype, ctx);
1190   }
1191 
1192   Status s;
1193   if (async) {
1194     // Note that `h` may not be ready yet. However, execution order will
1195     // ensure that `h` is ready before the copy is actually done.
1196     std::unique_ptr<EagerNode> node(
1197         new CopyToDeviceNode(h, *result, d, *ctx, async, mirror));
1198     s = executor->AddOrExecute(std::move(node));
1199   } else {
1200     CopyToDeviceNode node(h, *result, d, *ctx, async, mirror);
1201     s = executor->SyncExecute(&node);
1202   }
1203 
1204   // If the operation failed, we need to Unref any outputs that were
1205   // allocated.
1206   if (!s.ok()) {
1207     (*result)->Unref();
1208   }
1209 
1210   return s;
1211 }
1212 
1213 }  // namespace
1214 
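// A hedged usage sketch for the copy entry point defined below (`executor` and
// `dst_device` are illustrative placeholders; the local vs. remote dispatch
// happens inside the function):
//
//   TensorHandle* copied = nullptr;
//   Status s = EagerCopyToDevice(h, ctx, executor, dst_device,
//                                /*mirror=*/false, &copied);
//   if (!s.ok()) return s;  // on failure the output handle was already released
//   // ... use `copied`, then Unref it when no longer needed.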
1215 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1216                          EagerExecutor* executor, Device* device, bool mirror,
1217                          TensorHandle** result) {
1218   TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
1219   auto send_device = h->DeviceOrHostCPU(*ctx);
1220   bool sender_is_local = send_device->IsLocal();
1221 
1222   bool receiver_is_local = device->IsLocal();
1223 
1224   if (!executor->Async()) {
1225     // In sync mode, always clear error to maintain the same behavior as before.
1226     // TODO(b/141004939): Remove this.
1227     executor->ClearError();
1228   }
1229 
1230   if (sender_is_local && receiver_is_local) {
1231     return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result);
1232   } else {
1233 #if defined(IS_MOBILE_PLATFORM)
1234     return errors::Unimplemented(
1235         "Eager's remote execution is not available on mobile devices.");
1236 #else   // !IS_MOBILE_PLATFORM
1237     uint64 recv_op_id = 0;
1238     if (receiver_is_local) {
1239       Device* d = ctx->CanonicalDevice(device);
1240       // TODO(gjn): Need to add support for async execution. Note that if the
1241       // receiver is local, we first need to add support in TensorHandle to
1242       // wait on local mirrors.
1243       if (mirror) {
1244         h->Ref();
1245         *result = h;
1246 
1247         if (h->HasLocalMirror(d)) {
1248           return Status::OK();
1249         }
1250 
1251         Status s = h->AddEmptyLocalMirror(d);
1252         if (!s.ok()) {
1253           // If a mirror was added after our HasLocalMirror check, just return:
1254           // another thread has already added the mirror.
1255           if (s.code() == error::Code::ALREADY_EXISTS) {
1256             return Status::OK();
1257           }
1258 
1259           // Remove the previously added reference count since adding the mirror
1260           // failed.
1261           h->Unref();
1262           *result = nullptr;
1263           return s;
1264         }
1265       } else {
1266         *result = TensorHandle::CreateEmptyLocalHandle(
1267             /* d= */ d, /* op_device= */ device,
1268             /*resource_device=*/nullptr, h->dtype, ctx);
1269       }
1270     } else {
1271       if (mirror) {
1272         if (h->HasRemoteMirror(device, ctx->GetContextViewId())) {
1273           h->Ref();
1274           *result = h;
1275           return Status::OK();
1276         }
1277       }
1278       string remote_task;
1279       if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
1280         return errors::InvalidArgument(
1281             "Unable to find remote task corresponding to device ",
1282             device->name());
1283       }
1284       recv_op_id = ctx->RemoteMgr()->NextOpId();
1285       if (mirror) {
1286         TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
1287                                                       remote_task, ctx));
1288         h->Ref();
1289         *result = h;
1290       } else {
1291         *result = TensorHandle::CreateUnshapedRemoteHandle(
1292             recv_op_id, 0, remote_task, h->dtype, device, ctx);
1293       }
1294     }
1295 
1296     auto node = std::make_unique<eager::RemoteCopyNode>(
1297         ctx, executor, h, result[0], device, recv_op_id);
1298     Status s = executor->AddOrExecute(std::move(node));
1299     if (!s.ok()) {
1300       result[0]->Unref();
1301     }
1302     return s;
1303 #endif  // !IS_MOBILE_PLATFORM
1304   }
1305 }
1306 
1307 namespace {
1308 // Low-level utility function to execute the kernel specified by `kernel` on
1309 // `kernel->device()`, with the provided inputs as `op_inputs` in `ctx`.
1310 // Unlike `EagerKernelExecute`, which ties up the thread until the underlying
1311 // function finishes executing, this function does not block the thread and may
1312 // return before the function execution finishes. The provided `StatusCallback`
1313 // is invoked with the execution status once the function finishes.
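//
// A hedged sketch of how a caller might block on the callback when a result is
// needed synchronously (`absl::Notification` and the variable names below are
// illustrative only and not part of this file's logic):
//
//   absl::Notification done_flag;
//   Status ret;
//   EagerKernelExecuteAsync(ctx, inputs, /*remote_func_params=*/{},
//                           std::move(kernel), /*graph_collector=*/nullptr,
//                           cancellation_manager, retvals, num_outputs,
//                           [&](const Status& s) {
//                             ret = s;
//                             done_flag.Notify();
//                           });
//   done_flag.WaitForNotification();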
1314 void EagerKernelExecuteAsync(
1315     EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1316     const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1317     const core::RefCountPtr<KernelAndDevice> kernel,
1318     GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1319     TensorHandle** retvals, int num_outputs, StatusCallback done) {
1320   auto inputs = std::make_shared<ExecuteNodeArgs>(op_inputs.size());
1321   auto outputs = std::make_shared<std::vector<EagerKernelRet>>(1);
1322 
1323   Status s = inputs->Init(ctx, op_inputs, kernel);
1324   if (!s.ok()) {
1325     done(s);
1326     return;
1327   }
1328 
1329   kernel->Ref();  // Ownership of reference is transferred to the callback
1330   kernel->RunAsync(
1331       ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager,
1332       remote_func_params,
1333       [retvals, inputs, outputs, num_outputs, ctx, graph_collector,
1334        remote_func_params, kernel_raw = kernel.get(),
1335        done = std::move(done)](const Status& s) {
1336         auto wrapped_done = [&](const Status& s) {
1337           kernel_raw->Unref();
1338           done(s);
1339         };
1340         if (!s.ok()) {
1341           wrapped_done(s);
1342           return;
1343         }
1344         if (graph_collector != nullptr) {
1345           CollectGraphs(ctx);
1346         }
1347         DCHECK_EQ(num_outputs, outputs->size());
1348         wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx,
1349                                       kernel_raw, remote_func_params));
1350       });
1351 }
1352 }  // namespace
1353 
1354 // Low-level utility to run the eager operation on local devices. Unlike
1355 // `EagerLocalExecute`, which blocks and waits for the op execution to finish,
1356 // this method does not block the thread and may return before the eager
1357 // operation execution finishes. The provided `StatusCallback` is invoked with
1358 // the execution status once execution completes.
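//
// A hedged usage sketch (identifiers are illustrative; in practice the callback
// forwards the status back to the API layer that issued the op):
//
//   int num_retvals = expected_num_outputs;
//   std::vector<TensorHandle*> retvals(num_retvals, nullptr);
//   EagerLocalExecuteAsync(op, retvals.data(), &num_retvals,
//                          [](const Status& s) {
//                            if (!s.ok()) LOG(ERROR) << s.error_message();
//                          });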
1359 void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
1360                             int* num_retvals, StatusCallback done) {
1361   if (!op->IsLocal()) {
1362     done(errors::InvalidArgument(
1363         "Remote execution is not supported in async EagerLocalExecuteAsync"));
1364     return;
1365   }
1366 
1367   ScopedMemoryDebugAnnotation op_annotation(
1368       op->op_name(), op->remote_func_params().has_value()
1369                          ? op->remote_func_params().value().step_id.value_or(0)
1370                          : 0);
1371   profiler::TraceMe activity(
1372       [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); },
1373       profiler::TraceMeLevel::kInfo);
1374   EagerContext& ctx = op->EagerContext();
1375 
1376   core::RefCountPtr<KernelAndDevice> kernel;
1377   Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1378   if (!s.ok()) {
1379     done(s);
1380     return;
1381   }
1382 
1383   int num_outputs = kernel->num_outputs();
1384   s = ValidateInputTypeAndPlacement(&ctx, op, kernel);
1385   if (!s.ok()) {
1386     done(s);
1387     return;
1388   }
1389 
1390   if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
1391     string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
1392                                  kernel->device()->name());
1393     if (!logging::LogToListeners(msg)) {
1394       LOG(INFO) << msg;
1395     }
1396   }
1397 
1398   GraphCollector* graph_collector = nullptr;
1399   if (ctx.ShouldStoreGraphs()) {
1400     graph_collector = ctx.GetGraphCollector();
1401   }
1402 
1403   for (int i = 0, end = num_outputs; i < end; ++i) {
1404     const DataTypeVector& output_dtypes = kernel->output_dtypes();
1405     retvals[i] = TensorHandle::CreateEmptyLocalHandle(
1406         /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
1407         /* op_device= */ kernel->device(),
1408         /* resource_device= */ kernel->OutputResourceDevice(i),
1409         output_dtypes[i], &ctx);
1410   }
1411 
1412   const absl::InlinedVector<TensorHandle*, 4>* inputs;
1413   s = op->TensorHandleInputs(&inputs);
1414   if (!s.ok()) {
1415     done(s);
1416     return;
1417   }
1418   EagerKernelExecuteAsync(
1419       &ctx, *inputs, op->remote_func_params(), std::move(kernel),
1420       graph_collector, op->GetCancellationManager(), retvals, num_outputs,
1421       [op, num_outputs, retvals, done = std::move(done)](const Status& s) {
1422         op->Clear();
1423         // If the operation failed, we need to Unref any outputs that were
1424         // allocated.
1425         if (!s.ok()) {
1426           for (int i = 0, end = num_outputs; i < end; ++i) {
1427             if (retvals[i] != nullptr) {
1428               retvals[i]->Unref();
1429             }
1430           }
1431         }
1432         done(s);
1433       });
1434 }
1435 }  // namespace tensorflow
1436