1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/execute.h"
17 
18 #include <cstddef>
19 #include <vector>
20 
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "absl/container/btree_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_replace.h"
26 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
27 #include "tensorflow/core/framework/cancellation.h"
28 #include "tensorflow/core/framework/function.pb.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/tensor_shape.h"
33 #include "tensorflow/core/lib/core/refcount.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/platform.h"
36 #include "tensorflow/core/platform/protobuf.h"
37 
38 // clang-format on
39 
40 #include "absl/container/inlined_vector.h"
41 #include "absl/strings/match.h"
42 #include "absl/strings/str_cat.h"
43 #include "absl/types/optional.h"
44 #include "tensorflow/c/tf_tensor_internal.h"
45 #include "tensorflow/compiler/jit/defs.h"
46 #include "tensorflow/core/common_runtime/colocation_graph.h"
47 #include "tensorflow/core/common_runtime/device.h"
48 #include "tensorflow/core/common_runtime/device_set.h"
49 #include "tensorflow/core/common_runtime/eager/context.h"
50 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
51 #include "tensorflow/core/common_runtime/eager/execute_node.h"
52 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
53 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
54 #include "tensorflow/core/framework/dataset.h"
55 #include "tensorflow/core/framework/function.h"
56 #include "tensorflow/core/framework/logging.h"
57 #include "tensorflow/core/framework/node_def_util.h"
58 #include "tensorflow/core/framework/tensor_reference.h"
59 #include "tensorflow/core/framework/types.pb.h"
60 #include "tensorflow/core/lib/core/errors.h"
61 #include "tensorflow/core/profiler/lib/traceme.h"
62 #include "tensorflow/core/protobuf/error_codes.pb.h"
63 #include "tensorflow/core/util/device_name_utils.h"
64 #if !defined(IS_MOBILE_PLATFORM)
65 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
66 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
67 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
68 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
69 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
70 #endif  // IS_MOBILE_PLATFORM
71 #include "tensorflow/core/framework/step_stats.pb.h"
72 #include "tensorflow/core/framework/tensor.h"
73 #include "tensorflow/core/framework/types.h"
74 #include "tensorflow/core/lib/core/status.h"
75 #include "tensorflow/core/lib/gtl/cleanup.h"
76 #include "tensorflow/core/lib/gtl/flatset.h"
77 #include "tensorflow/core/lib/random/random.h"
78 #include "tensorflow/core/platform/env.h"
79 #include "tensorflow/core/platform/mutex.h"
80 #include "tensorflow/core/util/ptr_util.h"
81 #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
82 
83 #ifdef INTEL_MKL
84 #include "tensorflow/core/graph/mkl_graph_util.h"
85 #endif
86 
87 namespace tensorflow {
88 
89 namespace {
90 
91 const string& DeviceNameOrUnspecified(Device* device) {
92   static string* unspecified_string = new string("<unspecified>");
93   return (device == nullptr) ? *unspecified_string : device->name();
94 }
95 
96 // Returns whether a kernel should be cached.
97 bool KernelCacheEnabled(const OpDef& op_def) {
98   if (data::DatasetOpKernel::IsDatasetOp(&op_def)) {
99     return false;
100   }
101   // TODO(b/162540360): Revisit a way to mark kernels as uncachable once we have
102   // 5+ kernels to exclude.
103   return true;
104 }
105 
106 // This function expects *handle to point to an existing tensor handle that is
107 // currently on "handle_device", but where the operation expects that input to
108 // reside on "expected_input_device".  The function will arrange for this
109 // transfer to happen and will return OK on success and will store a new
110 // handle to the equivalent tensor on the correct device in "*result".  Or if an
111 // error is encountered, it will return a non-OK status and set "*result" to
112 // nullptr.
113 //
114 // `op_device` is passed in explicitly because `op->device()` might be
115 // unset and we might have selected some specific device to run this op on.
116 Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op,
117                                  Device* op_device,
118                                  TensorHandle* handle,  // op->Inputs()[i]
119                                  int i, Device* handle_device,
120                                  Device* expected_input_device,
121                                  TensorHandle** result) {
122   // Should only be called when these don't match
123   DCHECK(expected_input_device != handle_device);
124   *result = nullptr;
125   const string& op_device_name = DeviceNameOrUnspecified(op_device);
126 
127   switch (ctx->GetDevicePlacementPolicy()) {
128     case DEVICE_PLACEMENT_SILENT_FOR_INT32:
129       // TODO(xpan): See if we could bubble python related error up
130       // to python level.
131       if (handle->dtype == DT_INT32) {
132         // Note: enabling silent copies of int32 tensors to match behavior
133         // of graph mode.
134         break;
135       }
136       TF_FALLTHROUGH_INTENDED;
137     case DEVICE_PLACEMENT_EXPLICIT:
138       // tf.identity is allowed to copy, as indicated in the error message
139       // below.
140       if (op->Name() == "Identity" ||
141           op->Name() == "IdentityN"
142           // Constants start on CPU:0 and are copied via EagerConst to the
143           // current device.
144           || op->Name() == "_EagerConst") {
145         break;
146       }
147       return errors::InvalidArgument(
148           "Tensors on conflicting devices:"
149           " cannot compute ",
150           op->Name(), " as input #", i, " was expected to be on ",
151           expected_input_device->name(), " but is actually on ",
152           handle_device->name(), " (operation running on ", op_device_name, ")",
153           " Tensors can be copied explicitly using:"
154           " `with tf.device(device_name): x = tf.identity(x)`"
155           " or transparently copied by using"
156           " tf.config.experimental.set_device_policy('silent')."
157           " Copying tensors between devices may slow down your model");
158     case DEVICE_PLACEMENT_WARN:
159       LOG(WARNING) << "before computing " << op->Name() << " input #" << i
160                    << " was expected to be on " << expected_input_device->name()
161                    << " but is actually on " << handle_device->name()
162                    << " (operation running on " << op_device_name
163                    << "). This triggers a copy which can be a performance "
164                       "bottleneck.";
165       break;
166     case DEVICE_PLACEMENT_SILENT:  // Do nothing.
167       break;
168   }
169   // We are only here if the policy is warn or silent copies, so we should
170   // trigger a copy.
171   TensorHandle* result_handle = nullptr;
172   profiler::TraceMe activity(
173       [&] {
174         return absl::StrCat("_Send input ", i, " from ", handle_device->name(),
175                             " to ", expected_input_device->name());
176       },
177       profiler::TraceMeLevel::kInfo);
178   Status status =
179       EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device,
180                         /* mirror= */ true, &result_handle);
181   activity.Stop();
182   if (!status.ok()) {
183     return Status(
184         status.code(),
185         absl::StrCat("Failed copying input tensor from ", handle_device->name(),
186                      " to ", expected_input_device->name(), " in order to run ",
187                      op->Name(), ": ", status.error_message()));
188   }
189 
190   *result = result_handle;
191 
192   return Status::OK();
193 }
194 
195 // `op_device_name` the name of the device on which the op will run, if any.
196 // For functions running using function library runtime, the device can be
197 // unspecified.
198 Status ValidateInputTypeAndPlacement(
199     EagerContext* ctx, EagerOperation* op,
200     const core::RefCountPtr<KernelAndDevice>& kernel) {
201   profiler::TraceMe activity("ValidateInputTypeAndPlacement",
202                              profiler::TraceMeLevel::kInfo);
203   const int n_inputs = op->Inputs().size();
204   if (kernel->num_inputs() != n_inputs) {
205     return errors::InvalidArgument("expected ", kernel->num_inputs(),
206                                    " inputs, got ", n_inputs);
207   }
208   const bool is_function = kernel->IsFunction();
209   if (n_inputs > 0) {
210     const DataType* input_types = &kernel->input_dtypes()[0];
211     const absl::InlinedVector<TensorHandle*, 4>* handles;
212     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles));
213     for (int i = 0; i < n_inputs; ++i) {
214       TensorHandle* handle = (*handles)[i];
215       Device* expected_device = kernel->InputDevice(i);
216       if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) {
217         // Extract a handle on the op device from a packed input.
218         // This happens when a function is marked for XLA compilation.
219         // MaybePackInputTensor guarantees that a primitive op has no packed
220         // input at this point.
221         for (int j = 0; j < handle->NumPackedHandles(); ++j) {
222           TensorHandle* h = nullptr;
223           TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h));
224           if ((h->op_device() != nullptr) &&
225               (h->op_device()->name() == op->DeviceName())) {
226             op->UpdateInput(i, h);
227             handle = h;
228             break;
229           }
230         }
231       }
232       Device* handle_device = handle->DeviceOrHostCPU(*ctx);
233       const bool maybe_copy =
234           !is_function || handle->Type() != TensorHandle::REMOTE;
235       // If the input is already on the right device, then nothing to do.
236       if (expected_device != handle_device && maybe_copy) {
237         TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
238                                                      handle, i, handle_device,
239                                                      expected_device, &handle));
240         op->UpdateInput(i, handle);
241         // Unref handle since it has a ref as an input now
242         handle->Unref();
243       }
244       if (handle->dtype != input_types[i]) {
245         return errors::InvalidArgument(
246             "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
247             " was expected to be a ", DataTypeString(input_types[i]),
248             " tensor but is a ", DataTypeString(handle->dtype), " tensor");
249       }
250     }
251   }
252   return Status::OK();
253 }
254 
255 Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
256   const auto& node_def = op->MutableAttrs()->BuildNodeDef();
257   const OpDef* op_def = nullptr;
258 
259   const FunctionDef* function_def =
260       op->EagerContext().FuncLibDef()->Find(op->Name());
261   if (function_def != nullptr) {
262     op_def = &(function_def->signature());
263   } else {
264     TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
265   }
266 
267   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
268 
269   return Status::OK();
270 }
271 
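// Helper overloads that fold another fingerprint, or a single integer, into an
// Fprint128. They are used below to build kernel cache keys.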
272 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
273                                                const tensorflow::Fprint128& b) {
274   return {tensorflow::FingerprintCat64(a.low64, b.low64),
275           tensorflow::FingerprintCat64(a.high64, b.high64)};
276 }
277 
278 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
279                                                const int64_t b) {
280   auto x = tensorflow::FingerprintCat64(a.low64, b);
281   return {x, tensorflow::FingerprintCat64(a.high64, x)};
282 }
283 
284 Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
285                          Device** result) {
286   Device* cpu_device = ctx.HostCPU();
287   string device_name;
288   if (tensor_handle->Type() != TensorHandle::LOCAL) {
289     Device* device = tensor_handle->device();
290     device_name = device != nullptr ? device->name() : cpu_device->name();
291     *result = (device == nullptr ? cpu_device : device);
292   } else if (tensor_handle->dtype == DT_RESOURCE) {
293     // Use the resource's actual device because it is the device that will
294     // influence partitioning the multi-device function.
295     const Tensor* tensor;
296     // TODO(fishx): Avoid blocking here.
297     TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
298     const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
299     device_name = handle.device();
300 
301     Device* input_device;
302     TF_RETURN_IF_ERROR(
303         ctx.FindDeviceFromName(device_name.c_str(), &input_device));
304     *result = input_device;
305   } else {
306     Device* device = tensor_handle->device();
307     const bool is_tpu = device != nullptr && device->device_type() == "TPU";
308     // int32 return values can be placed on TPUs.
309     const bool use_host_memory =
310         is_tpu ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype)
311                : MTypeFromDType(tensor_handle->dtype);
312     if (use_host_memory) {
313       *result = cpu_device;
314     } else {
315       device_name = device != nullptr ? device->name() : cpu_device->name();
316       *result = (device == nullptr ? cpu_device : device);
317     }
318   }
319   return Status::OK();
320 }
321 
322 // Appends a TensorShape object to Fprint128 hash.
323 // For best performance, we would like to avoid dynamic memory allocation in
324 // this function.
325 // If "shape" has unknown rank, we attach "?" to hashed content; otherwise we
326 // attach every dim size to hashed content.
327 void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
328                                     Fprint128* fingerprint) {
329   if (shape.unknown_rank()) {
330     char c = '?';
331     *fingerprint = FingerprintCat128(*fingerprint, c);
332   } else {
333     for (int i = 0; i < shape.dims(); i++) {
334       int64_t dim = shape.dim_size(i);
335       *fingerprint = FingerprintCat128(*fingerprint, dim);
336     }
337   }
338 }
339 
340 Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx,
341                    const char* attr_name, bool* value) {
342   Status status = op->Attrs().Get(attr_name, value);
343   if (status.ok()) {
344     DVLOG(2) << "Caller explicitly specifies " << attr_name
345              << (*value ? "=true " : "=false, ") << op->DebugString();
346     return Status::OK();
347   }
348 
349   const FunctionDef* function_def =
350       ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name());
351   if (function_def == nullptr) {
352     return errors::NotFound("Failed to find function '", op->Name(), "'");
353   }
354 
355   status = GetNodeAttr(AttrSlice(&function_def->attr()), attr_name, value);
356   if (status.ok()) {
357     DVLOG(2) << "Function definition explicitly specifies " << attr_name
358              << (*value ? "=true" : "=false");
359     return Status::OK();
360   }
361   return status;
362 }
363 
364 Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
365                           bool* compile_with_xla) {
366   if (!op->is_function()) {
367     *compile_with_xla = false;
368     return Status::OK();
369   }
370 
371   if (op->remote_func_params().has_value() &&
372       op->remote_func_params().value().step_id.has_value()) {
373     // If the op is a component of a multi-device function, don't compile it
374     // with XLA.
375     *compile_with_xla = false;
376     return Status::OK();
377   }
378 
379   Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla);
380   if (status.ok()) {
381     return Status::OK();
382   }
383 
384   // No explicit requests. Compile for XLA devices by default.
385   if (op->GetDeviceParsedName().type == "TPU" ||
386       op->GetDeviceParsedName().type == "XLA_GPU" ||
387       op->GetDeviceParsedName().type == "XLA_CPU") {
388     DVLOG(2) << "Compiling " << op->Name()
389              << " with XLA because it is running on an XLA device "
390              << op->GetDeviceParsedName().type;
391     *compile_with_xla = true;
392   } else {
393     *compile_with_xla = false;
394   }
395 
396   return Status::OK();
397 }
398 
399 Status VerifyWrappableInCallOp(const OpDef& opdef, EagerOperation* op) {
400   absl::flat_hash_set<string> opdef_attrs;
401   for (const auto& attr : opdef.attr()) {
402     opdef_attrs.insert(attr.name());
403   }
404   const auto& node_def = op->MutableAttrs()->BuildNodeDef();
405   for (const auto& attr : node_def.attr()) {
406     if (opdef_attrs.find(attr.first) == opdef_attrs.end()) {
407       return errors::Unimplemented("EagerOperation: ", op->Name(),
408                                    " has a private attr '", attr.first, "'.");
409     }
410   }
411   return Status::OK();
412 }
413 
414 using ProtoArgListType = protobuf::RepeatedPtrField<OpDef_ArgDef>;
415 
416 string EscapeOrigName(const string& orig_name) {
417   // Replace _ with __ in the original name to avoid name conflicts.
418   return absl::StrReplaceAll(orig_name, {{"_", "__"}});
419 }
420 
421 // Variadic args are flattened during wrapping. This utility returns the name
422 // of a flattened arg/attr.
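// For example, GetFlatName("values", 1) returns "values_1"; underscores in the
// original name are first doubled by EscapeOrigName.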
423 string GetFlatName(const string orig_name, int index) {
424   return absl::StrCat(EscapeOrigName(orig_name), "_", index);
425 }
426 
427 // Builds the name of the wrapping FunctionDef for an eager op.
428 //
429 // For ops without variadic inputs/outputs, the name is simply __wrapped__OpType.
430 //
431 // For ops with variadic inputs/outputs, the arity of each variadic attr is
432 // encoded in the name. For example:
433 //
434 // IdentityN[T:[DT_FLOAT, DT_INT64]] -> __wrapped__IdentityN_T_2
435 // Concat[N:2, T:DT_FLOAT] -> __wrapped__Concat_N_2
436 Status BuildWrappedOpName(EagerOperation* op, const OpDef& opdef,
437                           const AbstractOpAttrs* op_attrs, string* name) {
438   string fname = absl::StrCat("__wrapped__", EscapeOrigName(op->Name()));
439   // For every variadic arg in `args`, populates `attr_to_len` with
440   // (attr_name, len(arg)).
441   auto FillAttrToLen = [op_attrs, op](
442                            const ProtoArgListType& args,
443                            absl::btree_map<string, int>* attr_to_len) {
444     for (const auto& arg : args) {
445       if (!arg.type_list_attr().empty()) {
446         gtl::InlinedVector<DataType, 4> type_list;
447         TF_RETURN_IF_ERROR(
448             op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
449         (*attr_to_len)[arg.type_list_attr()] = type_list.size();
450       } else if (!arg.number_attr().empty()) {
451         int64_t number_attr;
452         if (!op_attrs->GetInt(arg.number_attr(), &number_attr)) {
453           return errors::Internal("Unable to read attr ", arg.number_attr(),
454                                   " for op ", op->Name());
455         }
456         (*attr_to_len)[arg.number_attr()] = number_attr;
457       }
458     }
459     return Status::OK();
460   };
461   absl::btree_map<string, int> attr_to_len;
462   TF_RETURN_IF_ERROR(FillAttrToLen(opdef.input_arg(), &attr_to_len));
463   TF_RETURN_IF_ERROR(FillAttrToLen(opdef.output_arg(), &attr_to_len));
464   for (auto& name_len : attr_to_len) {
465     absl::StrAppend(&fname, "_", name_len.first, "_", name_len.second);
466   }
467   *name = fname;
468   return Status::OK();
469 }
470 
471 // Builds the signature of the wrapping FunctionDef for an eager op.
472 //
473 // For ops without variadic inputs/outputs, the signature is the same as the
474 // OpDef of the original op.
475 //
476 // Variadic inputs/outputs get flattened since we do not support executing
477 // functions with variadic signatures.
478 //
479 // TODO(srbs): These examples should be tests.
480 //
481 // Examples:
482 //
483 // Mixed type list:
484 //
485 // op {
486 //   name: "IdentityN"
487 //   input_arg {
488 //     name: "input"
489 //     type_list_attr: "T"
490 //   }
491 //   output_arg {
492 //     name: "output"
493 //     type_list_attr: "T"
494 //   }
495 //   attr {
496 //     name: "T"
497 //     type: "list(type)"
498 //     has_minimum: true
499 //     minimum: 1
500 //   }
501 // }
502 //
503 // With two inputs T=[DT_FLOAT, DT_INT64] would convert to
504 //
505 // op {
506 //   name: "__wrapped__IdentityN_T_2"
507 //   input_arg {
508 //     name: "input_0"
509 //     type_attr: "T_0"
510 //   }
511 //   input_arg {
512 //     name: "input_1"
513 //     type_attr: "T_1"
514 //   }
515 //   output_arg {
516 //     name: "output_0"
517 //     type_attr: "T_0"
518 //   }
519 //   output_arg {
520 //     name: "output_1"
521 //     type_attr: "T_1"
522 //   }
523 //   attr {
524 //     name: "T_0"
525 //     type: "type"
526 //   }
527 //   attr {
528 //     name: "T_1"
529 //     type: "type"
530 //   }
531 //   attr {
532 //     name: "T"
533 //     type: "list(type)"
534 //     has_minimum: true
535 //     minimum: 1
536 //   }
537 // }
538 //
539 // Note that the list(type) attr is preserved so that it can get copied to the
540 // inner op via a placeholder. This allows additional verification.
541 //
542 // Single type list:
543 //
544 // op {
545 //   name: "ConcatV2"
546 //   input_arg {
547 //     name: "values"
548 //     type_attr: "T"
549 //     number_attr: "N"
550 //   }
551 //   attr {
552 //     name: "N"
553 //     type: "int"
554 //     has_minimum: true
555 //     minimum: 2
556 //   }
557 //   attr {
558 //     name: "T"
559 //     type: "type"
560 //   }
561 //   [axis, output, Tidx are simply copied]
562 // }
563 //
564 // With two inputs N=2 would convert to:
565 //
566 // op {
567 //   name: "__wrapped__ConcatV2_N_2"
568 //   input_arg {
569 //     name: "values_0"
570 //     type_attr: "T"
571 //   }
572 //   input_arg {
573 //     name: "values_1"
574 //     type_attr: "T"
575 //   }
576 //   attr {
577 //     name: "N"
578 //     type: "int"
579 //     has_minimum: true
580 //     minimum: 2
581 //   }
582 //   attr {
583 //     name: "T"
584 //     type: "type"
585 //   }
586 //   [axis, output, Tidx are simply copied]
587 // }
588 //
589 // Note that the N attr is preserved so that it can get copied to the
590 // inner op via a placeholder. This allows additional verification.
591 Status BuildWrappedOpSignature(EagerOperation* op, const OpDef& opdef,
592                                const string& fname, OpDef& signature) {
593   signature = opdef;
594   signature.clear_input_arg();
595   signature.clear_output_arg();
596   signature.set_name(fname);
597   auto op_attrs = op->GetOpAttrs();
598   auto FillSignatureArgs = [op_attrs, op](
599                                const ProtoArgListType& opdef_args,
600                                ProtoArgListType* sig_args,
601                                absl::flat_hash_set<string>& new_attrs) {
602     for (const auto& arg : opdef_args) {
603       if (!arg.type_list_attr().empty()) {
604         gtl::InlinedVector<DataType, 4> type_list;
605         TF_RETURN_IF_ERROR(
606             op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
607         for (size_t i = 0; i < type_list.size(); i++) {
608           auto arg_def = sig_args->Add();
609           arg_def->set_name(GetFlatName(arg.name(), i));
610           auto attr_name = GetFlatName(arg.type_list_attr(), i);
611           new_attrs.insert(attr_name);
612           arg_def->set_type_attr(std::move(attr_name));
613         }
614       } else if (!arg.number_attr().empty()) {
615         int64_t number_attr;
616         if (!op_attrs->GetInt(arg.number_attr(), &number_attr)) {
617           return errors::Internal("Unable to read attr ", arg.number_attr(),
618                                   " for op ", op->Name());
619         }
620         for (size_t i = 0; i < number_attr; i++) {
621           auto arg_def = sig_args->Add();
622           arg_def->set_name(GetFlatName(arg.name(), i));
623           if (!arg.type_attr().empty()) {
624             arg_def->set_type_attr(arg.type_attr());
625           } else {
626             arg_def->set_type(arg.type());
627           }
628         }
629       } else {
630         auto arg_def = sig_args->Add();
631         *arg_def = arg;
632         arg_def->set_name(EscapeOrigName(arg.name()));
633         if (!arg.type_attr().empty()) {
634           // Don't escape: type attrs are still referenced by the original name.
635           arg_def->set_type_attr(arg.type_attr());
636         }
637       }
638     }
639     return Status::OK();
640   };
641   absl::flat_hash_set<string> new_attrs;
642   TF_RETURN_IF_ERROR(FillSignatureArgs(
643       opdef.input_arg(), signature.mutable_input_arg(), new_attrs));
644   TF_RETURN_IF_ERROR(FillSignatureArgs(
645       opdef.output_arg(), signature.mutable_output_arg(), new_attrs));
646   for (auto& attr_name : new_attrs) {
647     auto attr_def = signature.mutable_attr()->Add();
648     attr_def->set_name(attr_name);
649     attr_def->set_type("type");
650   }
651   return Status::OK();
652 }
653 
654 // For mixed type inputs "list(type)" we create new attributes in the signature
655 // for each element tensor (See examples in BuildWrappedOpSignature). Here
656 // we construct the values for those attributes and set them on the wrapped op.
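// For example, when wrapping IdentityN with T=[DT_FLOAT, DT_INT64], this sets
// the generated attrs T_0=DT_FLOAT and T_1=DT_INT64 on the wrapped call op.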
657 Status AddMixedTypeListAttrs(EagerOperation* wrapped_op,
658                              const AbstractOpAttrs* op_attrs,
659                              const OpDef& opdef) {
660   auto FillAttrsToAdd =
661       [op_attrs](const ProtoArgListType& opdef_args,
662                  absl::flat_hash_map<string, DataType>* attrs_to_add) {
663         for (const auto& arg : opdef_args) {
664           if (!arg.type_list_attr().empty()) {
665             gtl::InlinedVector<DataType, 4> type_list;
666             TF_RETURN_IF_ERROR(
667                 op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
668             for (size_t i = 0; i < type_list.size(); i++) {
669               auto attr_name = GetFlatName(arg.type_list_attr(), i);
670               (*attrs_to_add)[attr_name] = type_list[i];
671             }
672           }
673         }
674         return Status::OK();
675       };
676   absl::flat_hash_map<string, DataType> attrs_to_add;
677   TF_RETURN_IF_ERROR(FillAttrsToAdd(opdef.input_arg(), &attrs_to_add));
678   TF_RETURN_IF_ERROR(FillAttrsToAdd(opdef.output_arg(), &attrs_to_add));
679   for (auto& name_type : attrs_to_add) {
680     TF_RETURN_IF_ERROR(
681         wrapped_op->SetAttrType(name_type.first.data(), name_type.second));
682   }
683   // TODO(srbs): Rename all original attributes using EscapeOrigName.
684   return Status::OK();
685 }
686 
687 // Maps the op's outputs to the function outputs. Mainly useful for variadic
688 // outputs which need to be flattened.
689 Status PopulateRetMap(FunctionDef* fdef, const AbstractOpAttrs* op_attrs,
690                       const EagerOperation* op, const OpDef& opdef,
691                       const OpDef& signature, const string& node_name) {
692   int next_sig_output = 0;
693   for (size_t i = 0; i < opdef.output_arg_size(); i++) {
694     const auto& output_arg = opdef.output_arg(i);
695     if (!output_arg.type_list_attr().empty()) {
696       gtl::InlinedVector<DataType, 4> type_list;
697       TF_RETURN_IF_ERROR(
698           op_attrs->GetTypeList(output_arg.type_list_attr(), &type_list));
699       for (int j = 0; j < type_list.size(); j++) {
700         (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
701             absl::StrCat(node_name, ":", output_arg.name(), ":", j);
702       }
703     } else if (!output_arg.number_attr().empty()) {
704       int64_t number_attr;
705       if (!op_attrs->GetInt(output_arg.number_attr(), &number_attr)) {
706         return errors::Internal("Unable to read attr ",
707                                 output_arg.number_attr(), " for op ",
708                                 op->Name());
709       }
710       for (int j = 0; j < number_attr; j++) {
711         (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
712             absl::StrCat(node_name, ":", output_arg.name(), ":", j);
713       }
714     } else {
715       (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
716           absl::StrCat(node_name, ":", output_arg.name(), ":0");
717     }
718   }
719   return Status::OK();
720 }
721 
722 Status WrapInCallOp(EagerOperation* op, EagerOperation** wrapped_op) {
723   DCHECK(!op->is_function());
724   const OpDef& opdef = OpRegistry::Global()->LookUp(op->Name())->op_def;
725   // Raise an error for ops which don't support wrapping yet. This includes
726   // ops with list inputs/outputs and ops with private attrs.
727   // TODO(srbs): Support list inputs/outputs.
728   TF_RETURN_IF_ERROR(VerifyWrappableInCallOp(opdef, op));
729 
730   // Build a FunctionDef containing op as a node and register with context.
731   // TODO(srbs): Here we are unable to distinguish between a FunctionDef for
732   // a wrapped eager op and an existing user defined function registered with
733   // the context e.g. with something like
734   // @tf.function
735   // def __wrapped__Add(x, y):
736   //   ...
737   // This can be avoided by introducing a dict in EagerContext that stores a
738   // mapping from the eager op's name to its unique FunctionDef name.
739   auto op_attrs = op->GetOpAttrs();
740   string fname;
741   TF_RETURN_IF_ERROR(BuildWrappedOpName(op, opdef, op_attrs, &fname));
742   if (!op->EagerContext().GetFunctionDef(fname)) {
743     FunctionDef fdef;
744     // Set signature.
745     TF_RETURN_IF_ERROR(
746         BuildWrappedOpSignature(op, opdef, fname, *fdef.mutable_signature()));
747     // Add node.
748     NodeDef* ndef = fdef.add_node_def();
749     ndef->set_op(op->Name());
750     ndef->set_name(op->Name());  // This could be anything.
751     const auto& signature = fdef.signature();
752     for (size_t i = 0; i < signature.input_arg_size(); i++) {
753       ndef->add_input(absl::StrCat(fdef.signature().input_arg(i).name(), ":0"));
754     }
755     // TODO(srbs): Private attrs on the op are dropped here and applied to
756     // the call op instead. If this causes problems we might have to copy those
757     // attrs to this ndef. That would require updating fname to contain a hash
758     // of such attributes.
759     for (const auto& attr : opdef.attr()) {
760       (*ndef->mutable_attr())[attr.name()].set_placeholder(attr.name());
761     }
762 
763 #ifdef INTEL_MKL
764     if (IsMklEnabled() &&
765         absl::StartsWith(op->Name(), mkl_op_registry::kMklOpPrefix)) {
766       // All MKL eager ops have `_kernel` private attribute that needs to be set
767       // to a fixed label.
768       AttrValue attr_kernel;
769       attr_kernel.set_s(mkl_op_registry::kMklNameChangeOpLabel);
770       (*ndef->mutable_attr()).insert({"_kernel", attr_kernel});
771     }
772 #endif  // INTEL_MKL
773 
774     // Set `ret` map.
775     TF_RETURN_IF_ERROR(
776         PopulateRetMap(&fdef, op_attrs, op, opdef, signature, ndef->name()));
777     VLOG(1) << fdef.DebugString();
778     TF_RETURN_IF_ERROR(op->EagerContext().AddFunctionDef(std::move(fdef)));
779   }
780   // Build the call op.
781   auto& ctx = op->EagerContext();
782   AbstractOperationPtr call_op(ctx.CreateOperation());
783   TF_RETURN_IF_ERROR(call_op->Reset(fname.c_str(), op->DeviceName().c_str()));
784   for (auto t : op->Inputs()) {
785     TF_RETURN_IF_ERROR(call_op->AddInput(t));
786   }
787   TF_RETURN_IF_ERROR(call_op->SetDeviceName(op->DeviceName().c_str()));
788   *wrapped_op = down_cast<EagerOperation*>(call_op.release());
789   // Attributes on the elementary eager operation are applied to the call op and
790   // to the NodeDef inside the FunctionDef. This allows us to have a single
791   // FunctionDef for different attribute values. When the function is
792   // instantiated, these attributes get forwarded to the NodeDef. This is done
793   // by setting the AttrValue.placeholder field for the NodeDef attrs.
794   (*wrapped_op)->AddAttrs(op_attrs);
795   return AddMixedTypeListAttrs(*wrapped_op, op_attrs, opdef);
796 }
797 
798 Status GetOrCreateKernelAndDevice(
799     EagerOperation* op, TensorHandle** retvals, int* num_retvals,
800     core::RefCountPtr<KernelAndDevice>* out_kernel) {
801   EagerContext& ctx = op->EagerContext();
802   Device* device = absl::get<Device*>(op->Device());
803 
804   Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
805   // Include soft placement policy in cache key since the placement strategy
806   // can change and thus affect which kernel is picked.
807   cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
808   // The launch-time rendezvous reuse setting is bundled with the kernel, so we
809   // need to include it in the cache key.
810   cache_key =
811       FingerprintCat128(cache_key, ctx.GetReuseRendezvousForFunctions());
812 
813   std::vector<Device*> input_dev_ptrs;
814   absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
815   std::unordered_map<int, DtypeAndPartialTensorShape>
816       input_resource_variable_dtypes_and_shapes;
817   // We can eliminate some overhead by running simple functions using regular
818   // CallOp kernel. However, it is tricky to figure out which functions should
819   // be run using CallOp. Also, currently CallOp runs neither optimization
820   // passes (needed for TPU/XLA) nor grappler.
821   // Here are some cases where a function should be run in multi-device mode:
822   //  - Function takes at least two resources on different devices.
823   //  - Function takes a resource on deviceA and a body op explicitly placed
824   //  on deviceB.
825   //  - Function has a colocation constraint.
826   //  - Function has an explicit device annotation (which might not be using
827   //    full canonical device name) different from op_device. Note that false
828   //    positives are ok.
829   //  - Function has a node or a (node) attribute that can potentially make
830   //    the function multi-device after a rewrite pass (e.g. various XLA/TPU
831   //    special nodes and attributes)
832   if (op->is_function() || ctx.RunEagerOpAsFunction()) {
833     profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
834                                profiler::TraceMeLevel::kInfo);
835     input_dev_ptrs.reserve(op->Inputs().size());
836     const absl::InlinedVector<TensorHandle*, 4>* inputs;
837     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
838     for (int i = 0, end = inputs->size(); i < end; i++) {
839       TensorHandle* input = (*inputs)[i];
840 
841       // Get device for this input, and add it to 'cache_key'.
842       Device* input_device;
843       TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
844       input_dev_ptrs.push_back(input_device);
845       CompositeDevice* composite_device = nullptr;
846       if (ctx.FindCompositeDeviceFromName(input_device->name(),
847                                           &composite_device)
848               .ok()) {
849         composite_devices[input_device->name()] =
850             composite_device->underlying_devices();
851       }
852       cache_key =
853           FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
854 
855       // If input is a ResourceHandle, get its resource handle dtypes and shapes
856       // and add them to 'cache_key'.
857       if (input->dtype == DT_RESOURCE) {
858         // We only care about data type and shape for resource variable inputs.
859         // But we have no way to tell if input is resource variable (other than
860         // looking it up in ResourceMgr, which is slow). So we just get
861         // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
862         // resource_dtypes_and_shapes is not empty, take the first element.
863         std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
864         TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
865             &resource_dtypes_and_shapes));
866         if (!resource_dtypes_and_shapes.empty()) {
867           const DtypeAndPartialTensorShape& dtype_and_shape =
868               resource_dtypes_and_shapes.at(0);
869           input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
870 
871           // Add _Arg index, dtype and shape to "cache_key".
872           cache_key = FingerprintCat128(cache_key, i);
873           DataType dtype = dtype_and_shape.dtype;
874           cache_key = FingerprintCat128(cache_key, dtype);
875           AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key);
876         }
877       }
878     }
879   }
880 
881   core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
882   AbstractOperationPtr wrapped_op_releaser;
883   if (kernel == nullptr) {
884     if (ctx.RunEagerOpAsFunction() && !op->is_function()) {
885       EagerOperation* wrapped_op = nullptr;
886       TF_RETURN_IF_ERROR(WrapInCallOp(op, &wrapped_op));
887       DCHECK(wrapped_op);
888       DCHECK(wrapped_op->is_function());
889       wrapped_op_releaser.reset(wrapped_op);
890       op = wrapped_op;
891     }
892     DVLOG(2) << "Creating new kernel for " << op->Name() << " on device "
893              << DeviceNameOrUnspecified(absl::get<Device*>(op->Device()));
894     bool run_function_with_flr = false;
895     bool function_outputs_on_op_device = false;
896     if (op->is_function()) {
897       bool compile_with_xla;
898       TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
899       if (compile_with_xla) {
900         // Note that it is not ideal, but currently correct, to set this
901         // attribute after computing the kernel cache key above.
902         // Note: If the attribute is already set to true, this is a noop.
903         op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
904       } else {
905         run_function_with_flr = true;
906       }
907       GetFuncAttr(op, ctx, kOutputsOnOpDevice, &function_outputs_on_op_device)
908           .IgnoreError();
909     }
910 
911     const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
912     if (device == nullptr) {
913       // Here in local execute, set preferred device to be on the local task to
914       // avoid placing op on a remote device with higher priority.
915       const DeviceNameUtils::ParsedName& preferred_device =
916           DeviceNameUtils::HasSomeDetails(op->GetDeviceParsedName())
917               ? op->GetDeviceParsedName()
918               : DeviceNameUtils::AddressSpace(ctx.HostCPUParsedName());
919       TF_RETURN_IF_ERROR(ctx.SelectDevice(preferred_device, ndef, &device));
920 
921       DVLOG(1) << "Placer place op [" << op->Name()
922                << "] on device: " << device->name();
923       DVLOG(4) << "Available kernels for " << op->Name() << " are"
924                << KernelsRegisteredForOp(op->Name());
925       op->SetDevice(device);
926     }
927 
928     FunctionLibraryRuntime* flr =
929         device == nullptr ? nullptr : ctx.func_lib(device);
930     if (device != nullptr && flr == nullptr) {
931       return errors::NotFound(
932           "Unable to find a FunctionLibraryRuntime corresponding to device ",
933           device->name());
934     }
935     auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
936                                                                : ctx.runner();
937     GraphCollector* graph_collector = nullptr;
938     if (ctx.ShouldStoreGraphs()) {
939       graph_collector = ctx.GetGraphCollector();
940     }
941     // Treat the function as multi_device only when we are not compiling
942     // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
943     // will create an XlaLaunchOp kernel to compile and run the function.
944     if (run_function_with_flr) {
945       // Multi-device functions don't use the rendezvous from eager context.
946       // If we use that rendezvous, multiple concurrent calls to the same
947       // function will likely result in collisions. However, this also means
948       // that we don't support legitimate sending/receiving across function
949       // boundary.
950       DVLOG(2) << "Running " << ndef.op() << " using multi-device function. "
951                << "Full node_def=" << ndef.DebugString();
952       std::function<int64()> get_op_id = nullptr;
953 #if !defined(IS_MOBILE_PLATFORM)
954       get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
955 #endif  // IS_MOBILE_PLATFORM
956       kernel.reset(new KernelAndDeviceFunc(
957           flr, ctx.pflr(), std::move(input_dev_ptrs),
958           std::move(composite_devices),
959           std::move(input_resource_variable_dtypes_and_shapes), runner,
960           ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
961           function_outputs_on_op_device, ctx.RendezvousCreator(), get_op_id));
962     } else {
963       DVLOG(2) << "Running " << ndef.op() << " using op kernel. "
964                << "Full node_def=" << ndef.DebugString();
965       kernel.reset(new KernelAndDeviceOp(
966           ctx.GetRendezvous(), ctx.LogMemory(), flr, runner,
967           ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
968     }
969 
970     TF_RETURN_IF_ERROR(
971         kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector));
972 
973     if (op->is_function()) {
974       ctx.AddKernelToCache(cache_key, kernel.get());
975     } else {
976       // Exclude tf.data op kernels from being cached. The reason for this is
977       // that tf.data op kernels that accept a user-defined function will have a
978       // unique cache key every time they are executed (because the user-defined
979       // function is traced every time). Caching such kernels provides no
980       // benefit and in some cases results in linear memory growth of user
981       // programs that build input pipeline graphs in a loop.
982       const OpDef* op_def;
983       TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def));
984       if (KernelCacheEnabled(*op_def)) {
985         ctx.AddKernelToCache(cache_key, kernel.get());
986       }
987     }
988   }
989 
990   int num_outputs = kernel->num_outputs();
991   if (num_outputs > *num_retvals) {
992     return errors::InvalidArgument("Expecting ", num_outputs,
993                                    " outputs, but *num_retvals is ",
994                                    *num_retvals);
995   }
996   *num_retvals = num_outputs;
997 
998   kernel->Ref();  // Ownership of reference is passed to out_kernel.
999   out_kernel->reset(kernel.get());
1000   return Status::OK();
1001 }
1002 
1003 Status CreateUnshapedOutput(
1004     const KernelAndDevice& kernel, const int output_num, Device* output_device,
1005     const DataType& output_dtype,
1006     const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1007     EagerContext* ctx, TensorHandle** output) {
1008 #if defined(IS_MOBILE_PLATFORM)
1009   return errors::Unimplemented(
1010       "Remote outputs are not available on mobile devices.");
1011 #else  // !IS_MOBILE_PLATFORM
1012   int64_t op_id;
1013   if (remote_func_params.has_value()) {
1014     op_id = remote_func_params.value().op_id;
1015   } else {
1016     return errors::InvalidArgument(
1017         "Unable to find a remote op id for a remote output of ", kernel.name());
1018   }
1019   string remote_task;
1020   if (!DeviceNameUtils::GetTaskName(output_device->parsed_name(),
1021                                     &remote_task)) {
1022     return errors::InvalidArgument(
1023         "Unable to find remote task corresponding to device ",
1024         output_device->name());
1025   }
1026   if (ctx->RemoteMgr()->IsMaster()) {
1027     *output = TensorHandle::CreateUnshapedRemoteHandle(
1028         op_id, output_num, remote_task, output_dtype, output_device, ctx);
1029   } else {
1030     *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num,
1031                                                    output_dtype, output_device,
1032                                                    /*is_ready=*/false, ctx);
1033   }
1034   return Status::OK();
1035 #endif  // !IS_MOBILE_PLATFORM
1036 }
1037 
1038 Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
1039                         EagerOperation* op, TensorHandle** retvals) {
1040   EagerExecutor& executor = op->Executor();
1041   EagerContext& ctx = op->EagerContext();
1042   GraphCollector* graph_collector = nullptr;
1043   if (ctx.ShouldStoreGraphs()) {
1044     graph_collector = ctx.GetGraphCollector();
1045   }
1046   const int num_outputs = kernel->num_outputs();
1047   absl::optional<EagerRemoteFunctionParams> remote_func_params =
1048       op->remote_func_params();
1049   if (kernel->IsCrossProcess() && !remote_func_params.has_value()) {
1050     // Create an eager op id for a cross-process function if one does not exist.
1051 #if defined(IS_MOBILE_PLATFORM)
1052     return errors::Unimplemented(
1053         "Cross-process functions are not supported on mobile devices.");
1054 #else  // !IS_MOBILE_PLATFORM
1055     const int64_t op_id = ctx.RemoteMgr()->NextOpId();
1056     remote_func_params =
1057         EagerRemoteFunctionParams{op_id, /*step_id=*/absl::nullopt};
1058 #endif  // !IS_MOBILE_PLATFORM
1059   }
1060   if (executor.Async()) {
1061     const DataTypeVector& output_dtypes = kernel->output_dtypes();
1062     for (int i = 0, end = num_outputs; i < end; ++i) {
1063       Device* output_device = ctx.CanonicalDevice(kernel->OutputDevice(i));
1064       if (output_device == nullptr || output_device->IsLocal()) {
1065         retvals[i] = TensorHandle::CreateEmptyLocalHandle(
1066             /* d= */ output_device, /* op_device= */ kernel->device(),
1067             /* resource_device= */ kernel->OutputResourceDevice(i),
1068             output_dtypes[i], &ctx);
1069       } else {
1070         TF_RETURN_IF_ERROR(
1071             CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1072                                  remote_func_params, &ctx, &retvals[i]));
1073       }
1074     }
1075     const absl::InlinedVector<TensorHandle*, 4>* inputs;
1076     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1077     auto node = absl::make_unique<AsyncExecuteNode>(
1078         &ctx, *inputs, remote_func_params, std::move(kernel), graph_collector,
1079         op->GetCancellationManager(),
1080         absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
1081     // Release the inputs from the eager operation since the AsyncExecuteNode
1082     // would have taken ownership. This allows the inputs to be forwarded if
1083     // possible.
1084     op->Clear();
1085     // For async mode, execution order will make sure that all
1086     // input handles are ready before executing them.
1087     // TODO(b/137118203): Consider executing "cheap" kernels inline for
1088     // performance.
1089     return executor.AddOrExecute(std::move(node));
1090   } else {
1091     for (int i = 0, end = num_outputs; i < end; ++i) {
1092       retvals[i] = nullptr;
1093     }
1094     const absl::InlinedVector<TensorHandle*, 4>* inputs;
1095     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1096     ExecuteNode node(&ctx, *inputs, remote_func_params, kernel, graph_collector,
1097                      op->GetCancellationManager(),
1098                      {retvals, static_cast<size_t>(num_outputs)},
1099                      op->GetStackTrace());
1100     Status s = executor.SyncExecute(&node);
1101     // We release the inputs AFTER executing the operation in sync mode since
1102     // ExecuteNode does not increment the reference count and thus does not have
1103     // ownership of the inputs while executing.
1104     op->Clear();
1105     return s;
1106   }
1107 }
1108 
1109 // There are a lot of references to devices in this function and around.
1110 // Here is what they mean:
1111 //  EagerOperation::Device(): The device on which the user requested the op
1112 //    be executed, except if we had to change the device due to resource inputs
1113 //    or CPU pinning. If the user did not request a device, the op does not
1114 //    take resources, and we did not pin it to CPU, the device can be nullptr.
1115 //  KernelAndDevice::Device(): The first time we see an op (combined with
1116 //    its attributes), we need to create a KernelAndDevice object for it.
1117 //    If op->Device() is a nullptr, we select a device for the op when
1118 //    creating the KernelAndDevice. A concrete device will always be selected
1119 //    here except when `op` is a function to be executed using function library
1120 //    runtime. In this case, we don't select a device because running
1121 //    a function with explicitly requested device has different behavior than
1122 //    running without an explicitly requested device.
1123 Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
1124                          int* num_retvals) {
1125   ScopedMemoryDebugAnnotation op_annotation(
1126       op->op_name(), op->remote_func_params().has_value()
1127                          ? op->remote_func_params().value().step_id.value_or(0)
1128                          : 0);
1129   profiler::TraceMe activity(
1130       [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
1131       profiler::TraceMeLevel::kInfo);
1132   EagerContext& ctx = op->EagerContext();
1133   auto& executor = op->Executor();
1134   TF_RETURN_IF_ERROR(executor.status());
1135 
1136   core::RefCountPtr<KernelAndDevice> kernel;
1137   auto status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1138 
1139   // Run all the registered rewrite passes after the placement, regardless of
1140   // whether the placement is successful or not. The passes can either create
1141   // new ops (without placement) or update some fields of the input op.
1142   std::unique_ptr<tensorflow::EagerOperation> out_op;
1143   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1144       EagerOpRewriteRegistry::POST_PLACEMENT, op, &out_op));
1145   if (out_op) {
1146     op = out_op.get();
1147     // If the out op doesn't have a device, either because it is a new op or
1148     // the op wasn't placed successfully, then we do the placement again.
1149     if (op->Device() == kVariantDeviceNull) {
1150       status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1151     }
1152   }
1153   if (!status.ok()) return status;
1154 
1155   int num_outputs = kernel->num_outputs();
1156   TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
1157 
1158   if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
1159     string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
1160                                  kernel->device()->name());
1161     if (!logging::LogToListeners(msg)) {
1162       LOG(INFO) << msg;
1163     }
1164   }
1165 
1166   Status s = AddOrExecuteNode(std::move(kernel), op, retvals);
1167   // Since the operation failed, we need to Unref any outputs if they were
1168   // allocated.
1169   if (!s.ok()) {
1170     for (int i = 0, end = num_outputs; i < end; ++i) {
1171       if (retvals[i] != nullptr) {
1172         retvals[i]->Unref();
1173       }
1174     }
1175   }
1176 
1177   return s;
1178 }
1179 
1180 // Run a Pack op to pack the tensors pointed to by a packed input TensorHandle if
1181 // the op is a primitive op.
1182 Status MaybePackInputTensor(EagerOperation* op) {
1183   if (op->is_function()) {
1184     // Functions could take packed TensorHandles as inputs.
1185     return Status::OK();
1186   }
1187   EagerContext& ctx = op->EagerContext();
1188   const absl::InlinedVector<TensorHandle*, 4>* inputs;
1189   TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1190   for (int i = 0; i < inputs->size(); ++i) {
1191     TensorHandle* handle = (*inputs)[i];
1192     if (handle->Type() == TensorHandle::PACKED) {
1193       EagerOperation pack_op(&ctx);
1194       TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr,
1195                                        /*remote=*/false, /*executor=*/nullptr));
1196       pack_op.MutableAttrs()->Set("N", handle->NumPackedHandles());
1197       pack_op.MutableAttrs()->Set("T", handle->dtype);
1198       for (int i = 0; i < handle->NumPackedHandles(); ++i) {
1199         tensorflow::TensorHandle* h = nullptr;
1200         TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(i, &h));
1201         TF_RETURN_IF_ERROR(pack_op.AddInput(h));
1202       }
1203       int num_retvals = 1;
1204       absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
1205       TF_RETURN_IF_ERROR(
1206           EagerLocalExecute(&pack_op, retvals.data(), &num_retvals));
1207       tensorflow::TensorHandle* ret = retvals.at(0);
1208       op->UpdateInput(i, ret);
1209       ret->Unref();
1210     }
1211   }
1212   return Status::OK();
1213 }
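
// Illustrative sketch (hypothetical wrapper, not part of the original file):
// how MaybePackInputTensor slots into the local path. It is a no-op for
// functions and for ops without PACKED inputs, so it can be called
// unconditionally before EagerLocalExecute, mirroring EagerExecute below.
Status ExampleLocalExecuteWithPacking(EagerOperation* op,
                                      TensorHandle** retvals,
                                      int* num_retvals) {
  TF_RETURN_IF_ERROR(MaybePackInputTensor(op));
  return EagerLocalExecute(op, retvals, num_retvals);
}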
1214 
1215 #if !defined(IS_MOBILE_PLATFORM)
1216 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
1217   EagerContext& ctx = op->EagerContext();
1218 
1219   remote_op->set_id(ctx.RemoteMgr()->NextOpId());
1220   remote_op->set_name(op->Name());
1221 
1222   op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
1223   remote_op->set_device(absl::get<Device*>(op->Device())->name());
1224   remote_op->set_is_function(op->is_function());
1225 }
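
// Illustrative sketch (hypothetical helper, not part of the original file):
// the caller-side pattern for PrepareRemoteOp, as used by EagerRemoteExecute
// below. Assumes the op has already been placed on a concrete Device. The
// EnqueueRequest owns the eager::Operation that gets populated with the op id,
// name, attrs, device, and is_function flag.
Status ExamplePrepareEnqueueRequest(EagerOperation* op,
                                    eager::EnqueueRequest* request) {
  request->set_context_id(op->EagerContext().GetContextId());
  eager::Operation* remote_op = request->add_queue()->mutable_operation();
  PrepareRemoteOp(remote_op, op);
  return Status::OK();
}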
1226 
1227 Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
1228                                     const DataTypeVector& output_dtypes,
1229                                     TensorHandle** retvals) {
1230   if (remote_op.name() == "VarHandleOp") {
1231     if (output_dtypes.size() != 1) {
1232       return errors::Internal("VarHandleOp should only have one output.");
1233     }
1234     if (output_dtypes[0] != DT_RESOURCE) {
1235       return errors::Internal(
1236           "The output of VarHandleOp should be a DT_RESOURCE.");
1237     }
1238     AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
1239     const AttrValue* dtype;
1240     TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
1241     const AttrValue* shape;
1242     TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
1243     retvals[0]->SetResourceHandleDtypeAndShape(
1244         {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
1245   }
1246   return Status::OK();
1247 }
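
// Illustrative sketch (hypothetical helper, not part of the original file):
// the AttrSlice lookup pattern used above, in isolation. Find() returns an
// error if the attribute is missing, so malformed requests fail early.
Status ExampleReadVarHandleAttrs(const eager::Operation& remote_op,
                                 DataType* dtype, PartialTensorShape* shape) {
  AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
  const AttrValue* dtype_attr;
  TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype_attr));
  const AttrValue* shape_attr;
  TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape_attr));
  *dtype = dtype_attr->type();
  *shape = PartialTensorShape(shape_attr->shape());
  return Status::OK();
}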
1248 
1249 Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
1250                           int* num_retvals) {
1251   EagerContext& ctx = op->EagerContext();
1252 
1253   // TODO(fishx): Remove following code when lazy tensor copy is ready.
1254   if (op->Device() == kVariantDeviceNull) {
1255     tensorflow::Device* device = nullptr;
1256     string device_name = op->DeviceName();
1257     TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
1258     op->SetDevice(device);
1259   }
1260 
1261   core::RefCountPtr<eager::EagerClient> eager_client;
1262   uint64 context_id = ctx.GetContextId();
1263   TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client));
1264   string remote_task;
1265   if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
1266     return errors::InvalidArgument(
1267         "Unable to find remote task corresponding to device ",
1268         op->DeviceName());
1269   }
1270 
1271   std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
1272   request->set_context_id(context_id);
1273 
1274   eager::Operation* remote_op = request->add_queue()->mutable_operation();
1275 
1276   tensorflow::Device* op_device = absl::get<Device*>(op->Device());
1277   {
1278     profiler::TraceMe activity("CopyInputToExpectedDevice",
1279                                profiler::TraceMeLevel::kInfo);
1280     const bool is_function = op->is_function();
1281     const absl::InlinedVector<TensorHandle*, 4>* inputs;
1282     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1283     for (int i = 0, end = inputs->size(); i < end; i++) {
1284       tensorflow::TensorHandle* input = (*inputs)[i];
1285       tensorflow::Device* input_device = input->device();
1286       tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx);
1287       const string* input_device_name = &input_device_or_cpu->name();
1288       bool serialize_resource_dtype_and_shape = false;
1289       if (op_device != input_device &&
1290           // If the expected and actual devices are on the same task, don't
1291           // explicitly copy, and instead depend on the copy to happen locally
1292           // when the op is executed on the device.
1293           !ctx.OnSameTask(op_device, input_device)) {
1294         if (!is_function || input_device_or_cpu->IsLocal()) {
1295           tensorflow::Device* remote_cpu_device;
1296           TF_RETURN_IF_ERROR(
1297               ctx.CPUDeviceOnTask(op_device, &remote_cpu_device));
1298           // Always copy to the remote CPU so that the actual device can be
1299           // correctly determined after the kernel is selected/instantiated,
1300           // since the op might have its inputs on host memory.
1301           TensorHandle* handle = input;
1302           Device* handle_device = handle->DeviceOrHostCPU(ctx);
1303           // If the input is already on the right device, then nothing to do.
1304           if (remote_cpu_device != handle_device) {
1305             TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
1306                 &ctx, op, op_device, handle, i, handle_device,
1307                 remote_cpu_device, &handle));
1308             op->UpdateInput(i, handle);
1309             input = handle;
1310             input_device = remote_cpu_device;
1311             input_device_name = &remote_cpu_device->name();
1312             // Unref handle since it has a ref as an input now
1313             handle->Unref();
1314           }
1315         } else {
1316           serialize_resource_dtype_and_shape =
1317               (input->dtype == DT_RESOURCE) &&
1318               (!input->HasResourceShapeMirror(op_device,
1319                                               ctx.GetContextViewId()));
1320         }
1321       }
1322       auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
1323       // For a remote component function, a function execution request and an
1324       // input generation request may come from different workers. We need to
1325       // guarantee that the input generation request is processed before the
1326       // function execution request, so wait until the remote input is ready
1327       // before sending it to the multi-device function device.
1328       const bool wait_until_ready = op->is_function();
1329       TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
1330           input, wait_until_ready, input_handle, input_device,
1331           *input_device_name, serialize_resource_dtype_and_shape));
1332       if (!input_handle->resource_dtypes_and_shapes().empty()) {
1333         TF_RETURN_IF_ERROR(
1334             input->AddResourceShapeMirror(op_device, input_handle->op_id(),
1335                                           input_handle->output_num(), &ctx));
1336       }
1337     }
1338   }
1339 
1340   PrepareRemoteOp(remote_op, op);
1341 
1342   DataTypeVector output_dtypes;
1343   TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
1344 
1345   const size_t num_outputs = output_dtypes.size();
1346   if (num_outputs != *num_retvals) {
1347     return errors::InvalidArgument(
1348         "num_retvals does not match expected output dtypes");
1349   }
1350   *num_retvals = num_outputs;
1351 
1352   const tensorflow::uint64 id = remote_op->id();
1353   for (size_t i = 0; i < num_outputs; ++i) {
1354     // TODO(nareshmodi): Change the callback to instead add the decref to a
1355     // list of pending decrefs that we can send as a batch with the next
1356     // execute.
1357 
1358     // The device_ and resource_device_ of this TensorHandle might be
1359     // incorrect. For multi-device functions, we don't know the output device
1360     // until the function is instantiated on a remote worker. Luckily, we don't
1361     // need to know the correct remote device here. We just need to know that it
1362     // is remote. If we need to copy this tensor to this process or run any ops
1363     // which take this tensor as an input, block until the correct device is
1364     // set.
1365     const bool unknown_device = op->is_function();
1366     retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
1367         id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
1368   }
1369 
1370   // Store the data type and shape of a remote resource variable on the
1371   // corresponding remote TensorHandle (output of 'VarHandleOp').
1372   // If the variable is an input of a remote function, the function may need
1373   // the type and shape during function instantiation. Store the type and
1374   // shape on the eager master and send them to the default function device along
1375   // with the EnqueueRequest.
1376   TF_RETURN_IF_ERROR(
1377       StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
1378 
1379   auto& executor = op->Executor();
1380   DVLOG(4) << "Execute remote eager op: " << op->Name()
1381            << " (is async?: " << executor.Async() << ").";
1382 
1383   const absl::InlinedVector<TensorHandle*, 4>* inputs;
1384   TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1385 
1386   std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
1387       &op->EagerContext(), std::move(request), op_device,
1388       ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(),
1389       op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
1390       *inputs, {retvals, num_outputs}));
1391 
1392   if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
1393     string msg = strings::StrCat(
1394         "Executing op ", op->Name(), " on task ",
1395         DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
1396     if (!logging::LogToListeners(msg)) {
1397       LOG(INFO) << msg;
1398     }
1399   }
1400 
1401   Status s = executor.AddOrExecute(std::move(node));
1402   // If the operation failed, we need to Unref any outputs that were
1403   // allocated.
1404   if (!s.ok()) {
1405     for (size_t i = 0; i < num_outputs; ++i) {
1406       retvals[i]->Unref();
1407     }
1408   }
1409 
1410   return s;
1411 }
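
// Illustrative sketch (hypothetical helper, not part of the original file):
// allocating remote output handles before the enqueue, mirroring the loop
// above. For multi-device functions the true output device is unknown until
// remote instantiation, so unknown_device=true makes readers block until the
// device is resolved.
Status ExampleAllocateRemoteOutputs(EagerOperation* op, uint64 op_id,
                                    const string& remote_task,
                                    const DataTypeVector& output_dtypes,
                                    TensorHandle** retvals) {
  EagerContext& ctx = op->EagerContext();
  Device* op_device = absl::get<Device*>(op->Device());
  const bool unknown_device = op->is_function();
  for (size_t i = 0; i < output_dtypes.size(); ++i) {
    retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
        op_id, i, remote_task, output_dtypes[i], op_device, &ctx,
        unknown_device);
  }
  return Status::OK();
}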
1412 #endif  // IS_MOBILE_PLATFORM
1413 
1414 Status GetKernelOutputs(
1415     std::vector<EagerKernelRet>* outputs, int num_outputs,
1416     TensorHandle** retvals, EagerContext* ctx, KernelAndDevice* kernel,
1417     const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
1418   for (int i = 0, end = num_outputs; i < end; ++i) {
1419     if (retvals[i] == nullptr) {
1420       EagerKernelRet& ret = (*outputs)[i];
1421       Device* output_device = ctx->CanonicalDevice(kernel->OutputDevice(i));
1422       if (ret.index() == 0) {
1423         retvals[i] = TensorHandle::CreateLocalHandle(
1424             std::move(absl::get<Tensor>(ret)),
1425             /* d= */ output_device,
1426             /* op_device= */ kernel->device(),
1427             /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
1428       } else {
1429         const DataTypeVector& output_dtypes = kernel->output_dtypes();
1430         TF_RETURN_IF_ERROR(
1431             CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1432                                  remote_func_params, ctx, &retvals[i]));
1433 #if !defined(IS_MOBILE_PLATFORM)
1434         TF_RETURN_IF_ERROR(
1435             retvals[i]->SetRemoteShape(absl::get<TensorShape>(ret),
1436                                        output_device, ctx->GetContextViewId()));
1437 #endif  // IS_MOBILE_PLATFORM
1438       }
1439     } else {
1440       if (!kernel->IsFunction() &&
1441           TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) {
1442         return errors::Internal(
1443             "Kernel output tensor handle has a different op device than the "
1444             "kernel. This should never happen.");
1445       }
1446       if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
1447                            retvals[i]->device())) {
1448         return errors::Internal(
1449             "Kernel output tensor handle is on a different device than "
1450             "the specified kernel output device. This should never happen.");
1451       }
1452 
1453       EagerKernelRet& ret = (*outputs)[i];
1454       if (ret.index() == 0) {
1455         TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
1456             std::move(absl::get<Tensor>(ret)),
1457             ctx->CanonicalDevice(kernel->OutputDevice(i))));
1458       } else {
1459 #if defined(IS_MOBILE_PLATFORM)
1460         return errors::Unimplemented(
1461             "Remote outputs are not available on mobile devices.");
1462 #else  // !IS_MOBILE_PLATFORM
1463         TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape(
1464             absl::get<TensorShape>(ret), retvals[i]->device(),
1465             ctx->GetContextViewId()));
1466 #endif  // !IS_MOBILE_PLATFORM
1467       }
1468     }
1469   }
1470   return Status::OK();
1471 }
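
// Illustrative sketch (hypothetical helper, not part of the original file):
// the variant dispatch used above for a single local output. An EagerKernelRet
// with index() == 0 carries the produced Tensor; other indices carry only a
// shape for outputs that live remotely.
Status ExampleWrapLocalOutput(EagerKernelRet* ret, Device* output_device,
                              KernelAndDevice* kernel, int i, EagerContext* ctx,
                              TensorHandle** out) {
  if (ret->index() != 0) {
    return errors::Internal("Expected a locally produced Tensor output.");
  }
  *out = TensorHandle::CreateLocalHandle(
      std::move(absl::get<Tensor>(*ret)),
      /* d= */ output_device,
      /* op_device= */ kernel->device(),
      /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
  return Status::OK();
}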
1472 
1473 void CollectGraphs(EagerContext* ctx) {
1474   mutex_lock ml(*ctx->MetadataMu());
1475 
1476   GraphCollector* collector = ctx->GetGraphCollector();
1477   mutex_lock mll(collector->mu);
1478 
1479   // Adding to partition graphs for backward compatibility.
1480   for (const auto& graph : collector->partitioned_graphs) {
1481     *ctx->RunMetadataProto()->add_partition_graphs() = graph;
1482   }
1483 
1484   if (collector->dirty) {
1485     auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
1486     *function_graphs->mutable_post_optimization_graph() =
1487         collector->optimized_graph;
1488     *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph;
1489     for (const auto& graph : collector->partitioned_graphs) {
1490       *function_graphs->add_partition_graphs() = graph;
1491     }
1492   }
1493 
1494   collector->ClearGraphs();
1495 }
1496 }  // namespace
1497 
1498 Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
1499                     int* num_retvals) {
1500   profiler::TraceMe activity(
1501       [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
1502       profiler::TraceMeLevel::kInfo);
1503 
1504   if (!op->Executor().Async()) {
1505     // In sync mode, always clear error to maintain the same behavior as before.
1506     // TODO(b/141004939): Remove this.
1507     op->Executor().ClearError();
1508   }
1509 
1510   std::unique_ptr<tensorflow::EagerOperation> out_op;
1511   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1512       EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
1513 
1514   if (op->IsLocal()) {
1515     if (out_op) {
1516       op = out_op.get();
1517     }
1518     TF_RETURN_IF_ERROR(MaybePackInputTensor(op));
1519     return EagerLocalExecute(op, retvals, num_retvals);
1520   }
1521 
1522 #if defined(IS_MOBILE_PLATFORM)
1523   return errors::Unimplemented(
1524       "Eager's remote execution is not available on mobile devices.");
1525 #else   // !IS_MOBILE_PLATFORM
1526   if (out_op) {
1527     op = out_op.get();
1528   }
1529   return EagerRemoteExecute(op, retvals, num_retvals);
1530 #endif  // !IS_MOBILE_PLATFORM
1531 }
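
// Illustrative sketch (hypothetical call site, not part of the original file):
// driving a single primitive op through EagerExecute. Assumes `ctx`, `a`, and
// `b` are valid; the caller owns one reference to each returned handle.
Status ExampleRunMatMul(EagerContext* ctx, TensorHandle* a, TensorHandle* b,
                        TensorHandle** out) {
  EagerOperation matmul(ctx);
  TF_RETURN_IF_ERROR(matmul.Reset("MatMul", /*device_name=*/nullptr,
                                  /*remote=*/false, /*executor=*/nullptr));
  TF_RETURN_IF_ERROR(matmul.AddInput(a));
  TF_RETURN_IF_ERROR(matmul.AddInput(b));
  int num_retvals = 1;
  return EagerExecute(&matmul, out, &num_retvals);
}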
1532 
1533 // TODO(gjn): Consider moving into ExecuteNode class
1534 Status EagerKernelExecute(
1535     EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1536     const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1537     const core::RefCountPtr<KernelAndDevice>& kernel,
1538     GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1539     absl::Span<TensorHandle*> retvals,
1540     const absl::optional<ManagedStackTrace>& stack_trace) {
1541   profiler::TraceMe activity("EagerKernelExecute",
1542                              profiler::TraceMeLevel::kInfo);
1543   std::vector<EagerKernelRet> outputs(1);
1544 
1545   ExecuteNodeArgs inputs(op_inputs.size());
1546   TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs, kernel));
1547   // TODO(apassos) figure out how to record stats for ops which are a part of
1548   // functions.
1549   // TODO(b/111859745): When we support recovering from kernel/device errors, we
1550   // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA
1551   // device. We don't call it now because it is an unneeded overhead (it
1552   // acquires a lock) and we can't recover from errors anyway.
1553   ScopedStepContainer* container = ctx->StepContainer();
1554   CoordinationServiceAgent* coord_agent = nullptr;
1555 #if !defined(IS_MOBILE_PLATFORM)
1556   if (ctx->GetDistributedManager() != nullptr)
1557     coord_agent = ctx->GetDistributedManager()->GetCoordinationServiceAgent();
1558 #endif  // !IS_MOBILE_PLATFORM
1559   TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
1560                                  cancellation_manager, remote_func_params,
1561                                  stack_trace, coord_agent));
1562   if (graph_collector != nullptr) {
1563     CollectGraphs(ctx);
1564   }
1565 
1566   if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) {
1567     return errors::Internal(
1568         "EagerKernelExecute returns a list of ", outputs.size(),
1569         " tensors but ", retvals.size(),
1570         " is expected. This should never "
1571         "happen. Please file a bug with the TensorFlow team.");
1572   }
1573   return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx,
1574                           kernel.get(), remote_func_params);
1575 }
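
// Illustrative sketch (hypothetical call site, not part of the original file):
// invoking the synchronous kernel path above with defaults. A real caller
// passes the op's cancellation manager and, when graphs should be stored, the
// context's graph collector.
Status ExampleRunKernelSync(EagerContext* ctx,
                            const absl::InlinedVector<TensorHandle*, 4>& inputs,
                            const core::RefCountPtr<KernelAndDevice>& kernel,
                            absl::Span<TensorHandle*> retvals) {
  return EagerKernelExecute(ctx, inputs, /*remote_func_params=*/absl::nullopt,
                            kernel, /*graph_collector=*/nullptr,
                            /*cancellation_manager=*/nullptr, retvals,
                            /*stack_trace=*/absl::nullopt);
}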
1576 
1577 namespace {
1578 
1579 Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1580                               EagerExecutor* executor, Device* dstd,
1581                               bool mirror, TensorHandle** result) {
1582   TF_RETURN_IF_ERROR(executor->status());
1583   Device* d = ctx->CanonicalDevice(dstd);
1584   if (mirror && h->HasLocalMirror(d)) {
1585     h->Ref();
1586     *result = h;
1587     return Status::OK();
1588   }
1589 
1590   bool async = executor->Async();
1591   if (mirror) {
1592     h->Ref();
1593     *result = h;
1594 
1595     if (h->HasLocalMirror(d)) {
1596       return Status::OK();
1597     }
1598 
1599     // We don't bother adding an empty local mirror in sync mode since we'll be
1600     // executing the operation directly and calling AddLocalMirror. A
1601     // reference count is still needed; it will be removed if the operation
1602     // fails.
1603     if (async) {
1604       Status s = h->AddEmptyLocalMirror(d);
1605       if (!s.ok()) {
1606         // If a mirror was added since we called HasLocalMirror then just return
1607         // since another thread has already added the mirror.
1608         if (s.code() == error::Code::ALREADY_EXISTS) {
1609           return Status::OK();
1610         }
1611 
1612         // Remove the previously added reference count since adding the mirror
1613         // failed.
1614         h->Unref();
1615         *result = nullptr;
1616         return s;
1617       }
1618     }
1619   } else {
1620     *result = TensorHandle::CreateEmptyLocalHandle(
1621         d, dstd, h->resource_device(), h->dtype, ctx);
1622   }
1623 
1624   Status s;
1625   if (async) {
1626     // Note that `h` may not be currently ready. However execution order will
1627     // make sure that `h` is ready before the copy is actually done.
1628     std::unique_ptr<EagerNode> node(
1629         new CopyToDeviceNode(h, *result, d, *ctx, async, mirror));
1630     s = executor->AddOrExecute(std::move(node));
1631   } else {
1632     CopyToDeviceNode node(h, *result, d, *ctx, async, mirror);
1633     s = executor->SyncExecute(&node);
1634   }
1635 
1636   // If the operation failed, we need to Unref any outputs that were
1637   // allocated.
1638   if (!s.ok()) {
1639     (*result)->Unref();
1640   }
1641 
1642   return s;
1643 }
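
// Illustrative sketch (hypothetical helper, not part of the original file):
// the mirror fast path above in isolation. If the destination device already
// holds a local mirror of `h`, the same handle is handed back with an extra
// reference and no copy is scheduled.
bool ExampleTryReuseLocalMirror(TensorHandle* h, EagerContext* ctx,
                                Device* dstd, TensorHandle** result) {
  Device* d = ctx->CanonicalDevice(dstd);
  if (h->HasLocalMirror(d)) {
    h->Ref();  // The caller receives its own reference to `h`.
    *result = h;
    return true;
  }
  return false;
}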
1644 
1645 }  // namespace
1646 
1647 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1648                          EagerExecutor* executor, Device* device, bool mirror,
1649                          TensorHandle** result) {
1650   TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
1651   auto send_device = h->DeviceOrHostCPU(*ctx);
1652   bool sender_is_local = send_device->IsLocal();
1653 
1654   bool receiver_is_local = device->IsLocal();
1655 
1656   if (!executor->Async()) {
1657     // In sync mode, always clear error to maintain the same behavior as before.
1658     // TODO(b/141004939): Remove this.
1659     executor->ClearError();
1660   }
1661 
1662   if (sender_is_local && receiver_is_local) {
1663     return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result);
1664   } else {
1665 #if defined(IS_MOBILE_PLATFORM)
1666     return errors::Unimplemented(
1667         "Eager's remote execution is not available on mobile devices.");
1668 #else   // !IS_MOBILE_PLATFORM
1669     uint64 recv_op_id = 0;
1670     if (receiver_is_local) {
1671       Device* d = ctx->CanonicalDevice(device);
1672       // TODO(gjn): Need to add support for async execution. Note if receiver
1673       // is local, we need to first add support in TensorHandle to wait on local
1674       // mirrors.
1675       if (mirror) {
1676         h->Ref();
1677         *result = h;
1678 
1679         if (h->HasLocalMirror(d)) {
1680           return Status::OK();
1681         }
1682 
1683         Status s = h->AddEmptyLocalMirror(d);
1684         if (!s.ok()) {
1685           // If a mirror was added since we called HasLocalMirror then just
1686           // return since another thread has already added the mirror.
1687           if (s.code() == error::Code::ALREADY_EXISTS) {
1688             return Status::OK();
1689           }
1690 
1691           // Remove the previously added reference count since adding the mirror
1692           // failed.
1693           h->Unref();
1694           *result = nullptr;
1695           return s;
1696         }
1697       } else {
1698         *result = TensorHandle::CreateEmptyLocalHandle(
1699             /* d= */ d, /* op_device= */ device,
1700             /*resource_device=*/nullptr, h->dtype, ctx);
1701       }
1702     } else {
1703       if (mirror) {
1704         if (h->HasRemoteMirror(device, ctx->GetContextViewId())) {
1705           h->Ref();
1706           *result = h;
1707           return Status::OK();
1708         }
1709       }
1710       string remote_task;
1711       if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
1712         return errors::InvalidArgument(
1713             "Unable to find remote task corresponding to device ",
1714             device->name());
1715       }
1716       recv_op_id = ctx->RemoteMgr()->NextOpId();
1717       if (mirror) {
1718         TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
1719                                                       remote_task, ctx));
1720         h->Ref();
1721         *result = h;
1722       } else {
1723         *result = TensorHandle::CreateUnshapedRemoteHandle(
1724             recv_op_id, 0, remote_task, h->dtype, device, ctx);
1725       }
1726     }
1727 
1728     auto node = std::make_unique<eager::RemoteCopyNode>(
1729         ctx, executor, h, result[0], device, recv_op_id);
1730     Status s = executor->AddOrExecute(std::move(node));
1731     if (!s.ok()) {
1732       result[0]->Unref();
1733     }
1734     return s;
1735 #endif  // !IS_MOBILE_PLATFORM
1736   }
1737 }
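
// Illustrative sketch (hypothetical call site, not part of the original file):
// a typical mirroring copy through EagerCopyToDevice. With mirror=true the
// returned handle may be `h` itself (with a mirror added), so callers should
// not assume `*dst` is a distinct object; they own one reference either way.
Status ExampleCopyWithMirror(TensorHandle* h, EagerContext* ctx,
                             EagerExecutor* executor, Device* device,
                             TensorHandle** dst) {
  return EagerCopyToDevice(h, ctx, executor, device, /*mirror=*/true, dst);
}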
1738 
1739 namespace {
1740 // Low-level utility function to execute the kernel specified by `kernel` on
1741 // `kernel->device()`, with the provided inputs as `op_inputs` in the 'ctx'.
1742 // Unlike `EagerKernelExecute`, which ties up the thread until the underlying
1743 // function finishes executing, this function does not block the thread and may
1744 // return before the function execution finishes. The provided
1745 // `StatusCallback` will be triggered after function execution with its status.
1746 void EagerKernelExecuteAsync(
1747     EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1748     const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1749     const core::RefCountPtr<KernelAndDevice> kernel,
1750     GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1751     TensorHandle** retvals, int num_outputs, StatusCallback done) {
1752   auto inputs = std::make_shared<ExecuteNodeArgs>(op_inputs.size());
1753   auto outputs = std::make_shared<std::vector<EagerKernelRet>>(1);
1754 
1755   Status s = inputs->Init(ctx, op_inputs, kernel);
1756   if (!s.ok()) {
1757     done(s);
1758     return;
1759   }
1760   CoordinationServiceAgent* coord_agent = nullptr;
1761 #if !defined(IS_MOBILE_PLATFORM)
1762   if (ctx->GetDistributedManager() != nullptr)
1763     coord_agent = ctx->GetDistributedManager()->GetCoordinationServiceAgent();
1764 #endif  // !IS_MOBILE_PLATFORM
1765 
1766   kernel->Ref();  // Ownership of reference is transferred to the callback
1767   kernel->RunAsync(
1768       ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager,
1769       remote_func_params, coord_agent,
1770       [retvals, inputs, outputs, num_outputs, ctx, graph_collector,
1771        remote_func_params, kernel_raw = kernel.get(),
1772        done = std::move(done)](const Status& s) {
1773         auto wrapped_done = [&](const Status& s) {
1774           kernel_raw->Unref();
1775           done(s);
1776         };
1777         if (!s.ok()) {
1778           wrapped_done(s);
1779           return;
1780         }
1781         if (graph_collector != nullptr) {
1782           CollectGraphs(ctx);
1783         }
1784         DCHECK_EQ(num_outputs, outputs->size());
1785         wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx,
1786                                       kernel_raw, remote_func_params));
1787       });
1788 }
1789 }  // namespace
1790 
1791 // Low-level utility to run the eager operation on local devices. Unlike
1792 // `EagerLocalExecute`, which blocks until the op execution finishes, this
1793 // method does not block the thread and may return before the
1794 // eager operation execution finishes. The provided `StatusCallback` will be
1795 // triggered after execution with its status.
1796 void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
1797                             int* num_retvals, StatusCallback done) {
1798   if (!op->IsLocal()) {
1799     done(errors::InvalidArgument(
1800         "Remote execution is not supported in async EagerLocalExecuteAsync"));
1801     return;
1802   }
1803 
1804   ScopedMemoryDebugAnnotation op_annotation(
1805       op->op_name(), op->remote_func_params().has_value()
1806                          ? op->remote_func_params().value().step_id.value_or(0)
1807                          : 0);
1808   profiler::TraceMe activity(
1809       [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); },
1810       profiler::TraceMeLevel::kInfo);
1811   EagerContext& ctx = op->EagerContext();
1812 
1813   core::RefCountPtr<KernelAndDevice> kernel;
1814   Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1815   if (!s.ok()) {
1816     done(s);
1817     return;
1818   }
1819 
1820   int num_outputs = kernel->num_outputs();
1821   s = ValidateInputTypeAndPlacement(&ctx, op, kernel);
1822   if (!s.ok()) {
1823     done(s);
1824     return;
1825   }
1826 
1827   if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
1828     string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
1829                                  kernel->device()->name());
1830     if (!logging::LogToListeners(msg)) {
1831       LOG(INFO) << msg;
1832     }
1833   }
1834 
1835   GraphCollector* graph_collector = nullptr;
1836   if (ctx.ShouldStoreGraphs()) {
1837     graph_collector = ctx.GetGraphCollector();
1838   }
1839 
1840   for (int i = 0, end = num_outputs; i < end; ++i) {
1841     const DataTypeVector& output_dtypes = kernel->output_dtypes();
1842     retvals[i] = TensorHandle::CreateEmptyLocalHandle(
1843         /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
1844         /* op_device= */ kernel->device(),
1845         /* resource_device= */ kernel->OutputResourceDevice(i),
1846         output_dtypes[i], &ctx);
1847   }
1848 
1849   const absl::InlinedVector<TensorHandle*, 4>* inputs;
1850   s = op->TensorHandleInputs(&inputs);
1851   if (!s.ok()) {
1852     done(s);
1853     return;
1854   }
1855   EagerKernelExecuteAsync(
1856       &ctx, *inputs, op->remote_func_params(), std::move(kernel),
1857       graph_collector, op->GetCancellationManager(), retvals, num_outputs,
1858       [op, num_outputs, retvals, done = std::move(done)](const Status& s) {
1859         op->Clear();
1860         // If the operation failed, we need to Unref any outputs that were
1861         // allocated.
1862         if (!s.ok()) {
1863           for (int i = 0, end = num_outputs; i < end; ++i) {
1864             if (retvals[i] != nullptr) {
1865               retvals[i]->Unref();
1866             }
1867           }
1868         }
1869         done(s);
1870       });
1871 }
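
// Illustrative sketch (hypothetical call site, not part of the original file):
// driving the async path above and blocking on a Notification just for the
// sketch; production callers typically chain work inside the StatusCallback.
void ExampleLocalExecuteAndWait(EagerOperation* op, TensorHandle** retvals,
                                int* num_retvals) {
  Notification done;
  Status status;
  EagerLocalExecuteAsync(op, retvals, num_retvals, [&](const Status& s) {
    status = s;
    done.Notify();
  });
  done.WaitForNotification();
  TF_CHECK_OK(status);  // Sketch only; real callers propagate the status.
}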
1872 }  // namespace tensorflow
1873