1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/execute.h"
17
18 #include <cstddef>
19 #include <vector>
20
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "absl/container/btree_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_replace.h"
26 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
27 #include "tensorflow/core/framework/cancellation.h"
28 #include "tensorflow/core/framework/function.pb.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/tensor_shape.h"
33 #include "tensorflow/core/lib/core/refcount.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/platform.h"
36 #include "tensorflow/core/platform/protobuf.h"
37
38 // clang-format on
39
40 #include "absl/container/inlined_vector.h"
41 #include "absl/strings/match.h"
42 #include "absl/strings/str_cat.h"
43 #include "absl/types/optional.h"
44 #include "tensorflow/c/tf_tensor_internal.h"
45 #include "tensorflow/compiler/jit/defs.h"
46 #include "tensorflow/core/common_runtime/colocation_graph.h"
47 #include "tensorflow/core/common_runtime/device.h"
48 #include "tensorflow/core/common_runtime/device_set.h"
49 #include "tensorflow/core/common_runtime/eager/context.h"
50 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
51 #include "tensorflow/core/common_runtime/eager/execute_node.h"
52 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
53 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
54 #include "tensorflow/core/framework/dataset.h"
55 #include "tensorflow/core/framework/function.h"
56 #include "tensorflow/core/framework/logging.h"
57 #include "tensorflow/core/framework/node_def_util.h"
58 #include "tensorflow/core/framework/tensor_reference.h"
59 #include "tensorflow/core/framework/types.pb.h"
60 #include "tensorflow/core/lib/core/errors.h"
61 #include "tensorflow/core/profiler/lib/traceme.h"
62 #include "tensorflow/core/protobuf/error_codes.pb.h"
63 #include "tensorflow/core/util/device_name_utils.h"
64 #if !defined(IS_MOBILE_PLATFORM)
65 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
66 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
67 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
68 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
69 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
70 #endif // IS_MOBILE_PLATFORM
71 #include "tensorflow/core/framework/step_stats.pb.h"
72 #include "tensorflow/core/framework/tensor.h"
73 #include "tensorflow/core/framework/types.h"
74 #include "tensorflow/core/lib/core/status.h"
75 #include "tensorflow/core/lib/gtl/cleanup.h"
76 #include "tensorflow/core/lib/gtl/flatset.h"
77 #include "tensorflow/core/lib/random/random.h"
78 #include "tensorflow/core/platform/env.h"
79 #include "tensorflow/core/platform/mutex.h"
80 #include "tensorflow/core/util/ptr_util.h"
81 #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
82
83 #ifdef INTEL_MKL
84 #include "tensorflow/core/graph/mkl_graph_util.h"
85 #endif
86
87 namespace tensorflow {
88
89 namespace {
90
const string& DeviceNameOrUnspecified(Device* device) {
92 static string* unspecified_string = new string("<unspecified>");
93 return (device == nullptr) ? *unspecified_string : device->name();
94 }
95
96 // Returns whether a kernel should be cached.
bool KernelCacheEnabled(const OpDef& op_def) {
98 if (data::DatasetOpKernel::IsDatasetOp(&op_def)) {
99 return false;
100 }
  // TODO(b/162540360): Revisit a way to mark kernels as uncacheable once we
  // have 5+ kernels to exclude.
103 return true;
104 }
105
106 // This function expects *handle to point to an existing tensor handle that is
107 // currently on "handle_device", but where the operation expects that input to
108 // reside on "expected_input_device". The function will arrange for this
// transfer to happen and will return OK on success and will store a new
110 // handle to the equivalent tensor on the correct device in "*result". Or if an
111 // error is encountered, it will return a non-OK status and set "*result" to
112 // nullptr.
113 //
114 // `op_device` is passed in explicitly because `op->device()` might be
115 // unset and we might have selected some specific device to run this op on.
Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op,
                                 Device* op_device,
                                 TensorHandle* handle,  // op->Inputs()[i]
                                 int i, Device* handle_device,
                                 Device* expected_input_device,
                                 TensorHandle** result) {
122 // Should only be called when these don't match
123 DCHECK(expected_input_device != handle_device);
124 *result = nullptr;
125 const string& op_device_name = DeviceNameOrUnspecified(op_device);
126
127 switch (ctx->GetDevicePlacementPolicy()) {
128 case DEVICE_PLACEMENT_SILENT_FOR_INT32:
129 // TODO(xpan): See if we could bubble python related error up
130 // to python level.
131 if (handle->dtype == DT_INT32) {
132 // Note: enabling silent copies of int32 tensors to match behavior
133 // of graph mode.
134 break;
135 }
136 TF_FALLTHROUGH_INTENDED;
137 case DEVICE_PLACEMENT_EXPLICIT:
138 // tf.identity is allowed to copy, as indicated in the error message
139 // below.
140 if (op->Name() == "Identity" ||
141 op->Name() == "IdentityN"
142 // Constants start on CPU:0 and are copied via EagerConst to the
143 // current device.
144 || op->Name() == "_EagerConst") {
145 break;
146 }
147 return errors::InvalidArgument(
148 "Tensors on conflicting devices:"
149 " cannot compute ",
150 op->Name(), " as input #", i, " was expected to be on ",
151 expected_input_device->name(), " but is actually on ",
152 handle_device->name(), " (operation running on ", op_device_name, ")",
153 " Tensors can be copied explicitly using:"
154 " `with tf.device(device_name): x = tf.identity(x)`"
155 " or transparently copied by using"
156 " tf.config.experimental.set_device_policy('silent')."
157 " Copying tensors between devices may slow down your model");
158 case DEVICE_PLACEMENT_WARN:
159 LOG(WARNING) << "before computing " << op->Name() << " input #" << i
160 << " was expected to be on " << expected_input_device->name()
161 << " but is actually on " << handle_device->name()
162 << " (operation running on " << op_device_name
163 << "). This triggers a copy which can be a performance "
164 "bottleneck.";
165 break;
166 case DEVICE_PLACEMENT_SILENT: // Do nothing.
167 break;
168 }
169 // We are only here if the policy is warn or silent copies, so we should
170 // trigger a copy.
171 TensorHandle* result_handle = nullptr;
172 profiler::TraceMe activity(
173 [&] {
174 return absl::StrCat("_Send input ", i, " from ", handle_device->name(),
175 " to ", expected_input_device->name());
176 },
177 profiler::TraceMeLevel::kInfo);
178 Status status =
179 EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device,
180 /* mirror= */ true, &result_handle);
181 activity.Stop();
182 if (!status.ok()) {
183 return Status(
184 status.code(),
185 absl::StrCat("Failed copying input tensor from ", handle_device->name(),
186 " to ", expected_input_device->name(), " in order to run ",
187 op->Name(), ": ", status.error_message()));
188 }
189
190 *result = result_handle;
191
192 return Status::OK();
193 }
194
// `op_device_name` is the name of the device on which the op will run, if
// any.
196 // For functions running using function library runtime, the device can be
197 // unspecified.
Status ValidateInputTypeAndPlacement(
    EagerContext* ctx, EagerOperation* op,
    const core::RefCountPtr<KernelAndDevice>& kernel) {
201 profiler::TraceMe activity("ValidateInputTypeAndPlacement",
202 profiler::TraceMeLevel::kInfo);
203 const int n_inputs = op->Inputs().size();
204 if (kernel->num_inputs() != n_inputs) {
205 return errors::InvalidArgument("expected ", kernel->num_inputs(),
206 " inputs, got ", n_inputs);
207 }
208 const bool is_function = kernel->IsFunction();
209 if (n_inputs > 0) {
210 const DataType* input_types = &kernel->input_dtypes()[0];
211 const absl::InlinedVector<TensorHandle*, 4>* handles;
212 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles));
213 for (int i = 0; i < n_inputs; ++i) {
214 TensorHandle* handle = (*handles)[i];
215 Device* expected_device = kernel->InputDevice(i);
216 if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) {
217 // Extract a handle on the op device from a packed input.
218 // This happens when a function is marked for XLA compilation.
219 // MaybePackInputTensor guarantees that a primitive op has no packed
220 // input at this point.
221 for (int j = 0; j < handle->NumPackedHandles(); ++j) {
222 TensorHandle* h = nullptr;
223 TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h));
224 if ((h->op_device() != nullptr) &&
225 (h->op_device()->name() == op->DeviceName())) {
226 op->UpdateInput(i, h);
227 handle = h;
228 break;
229 }
230 }
231 }
232 Device* handle_device = handle->DeviceOrHostCPU(*ctx);
233 const bool maybe_copy =
234 !is_function || handle->Type() != TensorHandle::REMOTE;
235 // If the input is already on the right device, then nothing to do.
236 if (expected_device != handle_device && maybe_copy) {
237 TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
238 handle, i, handle_device,
239 expected_device, &handle));
240 op->UpdateInput(i, handle);
241 // Unref handle since it has a ref as an input now
242 handle->Unref();
243 }
244 if (handle->dtype != input_types[i]) {
245 return errors::InvalidArgument(
246 "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
247 " was expected to be a ", DataTypeString(input_types[i]),
248 " tensor but is a ", DataTypeString(handle->dtype), " tensor");
249 }
250 }
251 }
252 return Status::OK();
253 }
254
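// Returns the output dtypes of `op` by building its NodeDef and consulting
// either the function library (for function ops) or the global op registry.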
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
256 const auto& node_def = op->MutableAttrs()->BuildNodeDef();
257 const OpDef* op_def = nullptr;
258
259 const FunctionDef* function_def =
260 op->EagerContext().FuncLibDef()->Find(op->Name());
261 if (function_def != nullptr) {
262 op_def = &(function_def->signature());
263 } else {
264 TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
265 }
266
267 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
268
269 return Status::OK();
270 }
271
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
274 return {tensorflow::FingerprintCat64(a.low64, b.low64),
275 tensorflow::FingerprintCat64(a.high64, b.high64)};
276 }
277
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const int64_t b) {
280 auto x = tensorflow::FingerprintCat64(a.low64, b);
281 return {x, tensorflow::FingerprintCat64(a.high64, x)};
282 }
283
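// Determines the device on which `tensor_handle` should be considered to
// reside for placement and cache-key purposes: non-local handles keep their
// current device (or host CPU), DT_RESOURCE inputs use the resource's actual
// device, and inputs whose dtype lives in host memory map to the host CPU.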
Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
                         Device** result) {
286 Device* cpu_device = ctx.HostCPU();
287 string device_name;
288 if (tensor_handle->Type() != TensorHandle::LOCAL) {
289 Device* device = tensor_handle->device();
290 device_name = device != nullptr ? device->name() : cpu_device->name();
291 *result = (device == nullptr ? cpu_device : device);
292 } else if (tensor_handle->dtype == DT_RESOURCE) {
    // Use the resource's actual device because it is the device that will
    // influence partitioning of the multi-device function.
295 const Tensor* tensor;
296 // TODO(fishx): Avoid blocking here.
297 TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
298 const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
299 device_name = handle.device();
300
301 Device* input_device;
302 TF_RETURN_IF_ERROR(
303 ctx.FindDeviceFromName(device_name.c_str(), &input_device));
304 *result = input_device;
305 } else {
306 Device* device = tensor_handle->device();
307 const bool is_tpu = device != nullptr && device->device_type() == "TPU";
308 // int32 return values can be placed on TPUs.
309 const bool use_host_memory =
310 is_tpu ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype)
311 : MTypeFromDType(tensor_handle->dtype);
312 if (use_host_memory) {
313 *result = cpu_device;
314 } else {
315 device_name = device != nullptr ? device->name() : cpu_device->name();
316 *result = (device == nullptr ? cpu_device : device);
317 }
318 }
319 return Status::OK();
320 }
321
322 // Appends a TensorShape object to Fprint128 hash.
323 // For best performance, we would like to avoid dynamic memory allocation in
324 // this function.
325 // If "shape" has unknown rank, we attach "?" to hashed content; otherwise we
326 // attach every dim size to hashed content.
void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
                                    Fprint128* fingerprint) {
329 if (shape.unknown_rank()) {
330 char c = '?';
331 *fingerprint = FingerprintCat128(*fingerprint, c);
332 } else {
333 for (int i = 0; i < shape.dims(); i++) {
334 int64_t dim = shape.dim_size(i);
335 *fingerprint = FingerprintCat128(*fingerprint, dim);
336 }
337 }
338 }
339
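// Looks up the boolean attr `attr_name`, first on the op itself and then on
// the FunctionDef it refers to; returns NotFound if the function is unknown.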
Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx,
                   const char* attr_name, bool* value) {
342 Status status = op->Attrs().Get(attr_name, value);
343 if (status.ok()) {
344 DVLOG(2) << "Caller explicitly specifies "
345 << (attr_name ? "=true " : "=false, ") << op->DebugString();
346 return Status::OK();
347 }
348
349 const FunctionDef* function_def =
350 ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name());
351 if (function_def == nullptr) {
352 return errors::NotFound("Failed to find function '", op->Name(), "'");
353 }
354
355 status = GetNodeAttr(AttrSlice(&function_def->attr()), attr_name, value);
356 if (status.ok()) {
357 DVLOG(2) << "Function definition explicitly specifies "
358 << (attr_name ? "=true" : "=false");
359 return Status::OK();
360 }
361 return status;
362 }
363
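// Decides whether `op` must be compiled with XLA: primitive ops and
// components of multi-device functions are not, functions with an explicit
// kXlaMustCompileAttr follow that attr, and otherwise functions placed on
// TPU/XLA_GPU/XLA_CPU devices are compiled by default.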
Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
                          bool* compile_with_xla) {
366 if (!op->is_function()) {
367 *compile_with_xla = false;
368 return Status::OK();
369 }
370
371 if (op->remote_func_params().has_value() &&
372 op->remote_func_params().value().step_id.has_value()) {
373 // If the op is a component of a multi-device function, don't compile it
374 // with XLA.
375 *compile_with_xla = false;
376 return Status::OK();
377 }
378
379 Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla);
380 if (status.ok()) {
381 return Status::OK();
382 }
383
384 // No explicit requests. Compile for XLA devices by default.
385 if (op->GetDeviceParsedName().type == "TPU" ||
386 op->GetDeviceParsedName().type == "XLA_GPU" ||
387 op->GetDeviceParsedName().type == "XLA_CPU") {
388 DVLOG(2) << "Compiling " << op->Name()
389 << " with XLA because it is running on an XLA device "
390 << op->GetDeviceParsedName().type;
391 *compile_with_xla = true;
392 } else {
393 *compile_with_xla = false;
394 }
395
396 return Status::OK();
397 }
398
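// Returns an error if `op` carries an attr that is not declared in its OpDef
// (i.e. a private attr), since such ops cannot be wrapped in a call op.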
Status VerifyWrappableInCallOp(const OpDef& opdef, EagerOperation* op) {
400 absl::flat_hash_set<string> opdef_attrs;
401 for (const auto& attr : opdef.attr()) {
402 opdef_attrs.insert(attr.name());
403 }
404 const auto& node_def = op->MutableAttrs()->BuildNodeDef();
405 for (const auto& attr : node_def.attr()) {
406 if (opdef_attrs.find(attr.first) == opdef_attrs.end()) {
407 return errors::Unimplemented("EagerOperation: ", op->Name(),
408 " has a private attr '", attr.first, "'.");
409 }
410 }
411 return Status::OK();
412 }
413
414 using ProtoArgListType = protobuf::RepeatedPtrField<OpDef_ArgDef>;
415
string EscapeOrigName(const string& orig_name) {
417 // Replace _ with __ in the original name to avoid name conflicts.
418 return absl::StrReplaceAll(orig_name, {{"_", "__"}});
419 }
420
421 // Variadic args are flattened during wrapping. This utility returns the name
422 // of a flattened arg/attr.
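// For example, GetFlatName("T", 0) is "T_0" and GetFlatName("my_arg", 1) is
// "my__arg_1" (the underscore is escaped by EscapeOrigName).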
string GetFlatName(const string orig_name, int index) {
424 return absl::StrCat(EscapeOrigName(orig_name), "_", index);
425 }
426
427 // Builds the name of the wrapping FunctionDef for an eager op.
428 //
// For ops without variadic inputs/outputs, the name is simply
// __wrapped__OpType.
430 //
431 // For ops with variadic inputs/outputs, the arity of each variadic attr is
432 // encoded in the name. For example:
433 //
434 // IdentityN[T:[DT_FLOAT, DT_INT64]] -> __wrapped__IdentityN_T_2
435 // Concat[N:2, T:DT_FLOAT] -> __wrapped__Concat_N_2
Status BuildWrappedOpName(EagerOperation* op, const OpDef& opdef,
                          const AbstractOpAttrs* op_attrs, string* name) {
438 string fname = absl::StrCat("__wrapped__", EscapeOrigName(op->Name()));
439 // For every variadic arg in `args`, populates `attr_to_len` with
440 // (attr_name, len(arg)).
441 auto FillAttrToLen = [op_attrs, op](
442 const ProtoArgListType& args,
443 absl::btree_map<string, int>* attr_to_len) {
444 for (const auto& arg : args) {
445 if (!arg.type_list_attr().empty()) {
446 gtl::InlinedVector<DataType, 4> type_list;
447 TF_RETURN_IF_ERROR(
448 op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
449 (*attr_to_len)[arg.type_list_attr()] = type_list.size();
450 } else if (!arg.number_attr().empty()) {
451 int64_t number_attr;
452 if (!op_attrs->GetInt(arg.number_attr(), &number_attr)) {
453 return errors::Internal("Unable to read attr ", arg.number_attr(),
454 " for op ", op->Name());
455 }
456 (*attr_to_len)[arg.number_attr()] = number_attr;
457 }
458 }
459 return Status::OK();
460 };
461 absl::btree_map<string, int> attr_to_len;
462 TF_RETURN_IF_ERROR(FillAttrToLen(opdef.input_arg(), &attr_to_len));
463 TF_RETURN_IF_ERROR(FillAttrToLen(opdef.output_arg(), &attr_to_len));
464 for (auto& name_len : attr_to_len) {
465 absl::StrAppend(&fname, "_", name_len.first, "_", name_len.second);
466 }
467 *name = fname;
468 return Status::OK();
469 }
470
471 // Builds the signature of the wrapping FunctionDef for an eager op.
472 //
473 // For ops without variadic inputs/outputs, the signature is the same as the
474 // OpDef of the original op.
475 //
476 // Variadic inputs/outputs get flattened since we do not support executing
477 // functions with variadic signatures.
478 //
479 // TODO(srbs): These examples should be tests.
480 //
481 // Examples:
482 //
483 // Mixed type list:
484 //
485 // op {
486 // name: "IdentityN"
487 // input_arg {
488 // name: "input"
489 // type_list_attr: "T"
490 // }
491 // output_arg {
492 // name: "output"
493 // type_list_attr: "T"
494 // }
495 // attr {
496 // name: "T"
497 // type: "list(type)"
498 // has_minimum: true
499 // minimum: 1
500 // }
501 // }
502 //
503 // With two inputs T=[DT_FLOAT, DT_INT64] would convert to
504 //
505 // op {
506 // name: "__wrapped__IdentityN_T_2"
507 // input_arg {
508 // name: "input_0"
509 // type_attr: "T_0"
510 // }
511 // input_arg {
512 // name: "input_1"
513 // type_attr: "T_1"
514 // }
515 // output_arg {
516 // name: "output_0"
517 // type_attr: "T_0"
518 // }
519 // output_arg {
520 // name: "output_1"
521 // type_attr: "T_1"
522 // }
523 // attr {
524 // name: "T_0"
525 // type: "type"
526 // }
527 // attr {
528 // name: "T_1"
529 // type: "type"
530 // }
531 // attr {
532 // name: "T"
533 // type: "list(type)"
534 // has_minimum: true
535 // minimum: 1
536 // }
537 // }
538 //
539 // Note that the list(type) attr is preserved so that it can get copied to the
540 // inner op via a placeholder. This allows additional verification.
541 //
542 // Single type list:
543 //
544 // op {
545 // name: "ConcatV2"
546 // input_arg {
547 // name: "values"
548 // type_attr: "T"
549 // number_attr: "N"
550 // }
551 // attr {
552 // name: "N"
553 // type: "int"
554 // has_minimum: true
555 // minimum: 2
556 // }
557 // attr {
558 // name: "T"
559 // type: "type"
560 // }
561 // [axis, output, Tidx are simply copied]
562 // }
563 //
564 // With two inputs N=2 would convert to:
565 //
566 // op {
567 // name: "__wrapped__ConcatV2_N_2"
568 // input_arg {
569 // name: "values_0"
570 // type_attr: "T"
571 // }
572 // input_arg {
573 // name: "values_1"
574 // type_attr: "T"
575 // }
576 // attr {
577 // name: "N"
578 // type: "int"
579 // has_minimum: true
580 // minimum: 2
581 // }
582 // attr {
583 // name: "T"
584 // type: "type"
585 // }
586 // [axis, output, Tidx are simply copied]
587 // }
588 //
589 // Note that the N attr is preserved so that it can get copied to the
590 // inner op via a placeholder. This allows additional verification.
Status BuildWrappedOpSignature(EagerOperation* op, const OpDef& opdef,
                               const string& fname, OpDef& signature) {
593 signature = opdef;
594 signature.clear_input_arg();
595 signature.clear_output_arg();
596 signature.set_name(fname);
597 auto op_attrs = op->GetOpAttrs();
598 auto FillSignatureArgs = [op_attrs, op](
599 const ProtoArgListType& opdef_args,
600 ProtoArgListType* sig_args,
601 absl::flat_hash_set<string>& new_attrs) {
602 for (const auto& arg : opdef_args) {
603 if (!arg.type_list_attr().empty()) {
604 gtl::InlinedVector<DataType, 4> type_list;
605 TF_RETURN_IF_ERROR(
606 op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
607 for (size_t i = 0; i < type_list.size(); i++) {
608 auto arg_def = sig_args->Add();
609 arg_def->set_name(GetFlatName(arg.name(), i));
610 auto attr_name = GetFlatName(arg.type_list_attr(), i);
611 new_attrs.insert(attr_name);
612 arg_def->set_type_attr(std::move(attr_name));
613 }
614 } else if (!arg.number_attr().empty()) {
615 int64_t number_attr;
616 if (!op_attrs->GetInt(arg.number_attr(), &number_attr)) {
617 return errors::Internal("Unable to read attr ", arg.number_attr(),
618 " for op ", op->Name());
619 }
620 for (size_t i = 0; i < number_attr; i++) {
621 auto arg_def = sig_args->Add();
622 arg_def->set_name(GetFlatName(arg.name(), i));
623 if (!arg.type_attr().empty()) {
624 arg_def->set_type_attr(arg.type_attr());
625 } else {
626 arg_def->set_type(arg.type());
627 }
628 }
629 } else {
630 auto arg_def = sig_args->Add();
631 *arg_def = arg;
632 arg_def->set_name(EscapeOrigName(arg.name()));
633 if (!arg.type_attr().empty()) {
634 // Don't escape: type attrs are still referenced by the original name.
635 arg_def->set_type_attr(arg.type_attr());
636 }
637 }
638 }
639 return Status::OK();
640 };
641 absl::flat_hash_set<string> new_attrs;
642 TF_RETURN_IF_ERROR(FillSignatureArgs(
643 opdef.input_arg(), signature.mutable_input_arg(), new_attrs));
644 TF_RETURN_IF_ERROR(FillSignatureArgs(
645 opdef.output_arg(), signature.mutable_output_arg(), new_attrs));
646 for (auto& attr_name : new_attrs) {
647 auto attr_def = signature.mutable_attr()->Add();
648 attr_def->set_name(attr_name);
649 attr_def->set_type("type");
650 }
651 return Status::OK();
652 }
653
654 // For mixed type inputs "list(type)" we create new attributes in the signature
655 // for each element tensor (See examples in BuildWrappedOpSignature). Here
656 // we construct the values for those attributes and set them on the wrapped op.
Status AddMixedTypeListAttrs(EagerOperation* wrapped_op,
                             const AbstractOpAttrs* op_attrs,
                             const OpDef& opdef) {
660 auto FillAttrsToAdd =
661 [op_attrs](const ProtoArgListType& opdef_args,
662 absl::flat_hash_map<string, DataType>* attrs_to_add) {
663 for (const auto& arg : opdef_args) {
664 if (!arg.type_list_attr().empty()) {
665 gtl::InlinedVector<DataType, 4> type_list;
666 TF_RETURN_IF_ERROR(
667 op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
668 for (size_t i = 0; i < type_list.size(); i++) {
669 auto attr_name = GetFlatName(arg.type_list_attr(), i);
670 (*attrs_to_add)[attr_name] = type_list[i];
671 }
672 }
673 }
674 return Status::OK();
675 };
676 absl::flat_hash_map<string, DataType> attrs_to_add;
677 TF_RETURN_IF_ERROR(FillAttrsToAdd(opdef.input_arg(), &attrs_to_add));
678 TF_RETURN_IF_ERROR(FillAttrsToAdd(opdef.output_arg(), &attrs_to_add));
679 for (auto& name_type : attrs_to_add) {
680 TF_RETURN_IF_ERROR(
681 wrapped_op->SetAttrType(name_type.first.data(), name_type.second));
682 }
683 // TODO(srbs): Rename all original attributes using EscapeOrigName.
684 return Status::OK();
685 }
686
687 // Maps the op's outputs to the function outputs. Mainly useful for variadic
688 // outputs which need to be flattened.
Status PopulateRetMap(FunctionDef* fdef, const AbstractOpAttrs* op_attrs,
                      const EagerOperation* op, const OpDef& opdef,
                      const OpDef& signature, const string& node_name) {
692 int next_sig_output = 0;
693 for (size_t i = 0; i < opdef.output_arg_size(); i++) {
694 const auto& output_arg = opdef.output_arg(i);
695 if (!output_arg.type_list_attr().empty()) {
696 gtl::InlinedVector<DataType, 4> type_list;
697 TF_RETURN_IF_ERROR(
698 op_attrs->GetTypeList(output_arg.type_list_attr(), &type_list));
699 for (int j = 0; j < type_list.size(); j++) {
700 (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
701 absl::StrCat(node_name, ":", output_arg.name(), ":", j);
702 }
703 } else if (!output_arg.number_attr().empty()) {
704 int64_t number_attr;
705 if (!op_attrs->GetInt(output_arg.number_attr(), &number_attr)) {
706 return errors::Internal("Unable to read attr ",
707 output_arg.number_attr(), " for op ",
708 op->Name());
709 }
710 for (int j = 0; j < number_attr; j++) {
711 (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
712 absl::StrCat(node_name, ":", output_arg.name(), ":", j);
713 }
714 } else {
715 (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
716 absl::StrCat(node_name, ":", output_arg.name(), ":0");
717 }
718 }
719 return Status::OK();
720 }
721
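// Wraps the primitive eager op `op` into a call to a generated FunctionDef
// (registering the FunctionDef with the context on first use) so that it can
// be executed through the function runtime. See BuildWrappedOpSignature above
// for how variadic args are flattened.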
Status WrapInCallOp(EagerOperation* op, EagerOperation** wrapped_op) {
723 DCHECK(!op->is_function());
724 const OpDef& opdef = OpRegistry::Global()->LookUp(op->Name())->op_def;
725 // Raise an error for ops which don't support wrapping yet. This includes
726 // ops with list inputs/outputs and ops with private attrs.
727 // TODO(srbs): Support list inputs/outputs.
728 TF_RETURN_IF_ERROR(VerifyWrappableInCallOp(opdef, op));
729
730 // Build a FunctionDef containing op as a node and register with context.
731 // TODO(srbs): Here we are unable to distinguish between a FunctionDef for
732 // a wrapped eager op and an existing user defined function registered with
733 // the context e.g. with something like
734 // @tf.function
735 // def __wrapped__Add(x, y):
736 // ...
737 // This can be avoided by introducing a dict in EagerContext that stores a
738 // mapping from the eager op's name to its unique FunctionDef name.
739 auto op_attrs = op->GetOpAttrs();
740 string fname;
741 TF_RETURN_IF_ERROR(BuildWrappedOpName(op, opdef, op_attrs, &fname));
742 if (!op->EagerContext().GetFunctionDef(fname)) {
743 FunctionDef fdef;
744 // Set signature.
745 TF_RETURN_IF_ERROR(
746 BuildWrappedOpSignature(op, opdef, fname, *fdef.mutable_signature()));
747 // Add node.
748 NodeDef* ndef = fdef.add_node_def();
749 ndef->set_op(op->Name());
750 ndef->set_name(op->Name()); // This could be anything.
751 const auto& signature = fdef.signature();
752 for (size_t i = 0; i < signature.input_arg_size(); i++) {
753 ndef->add_input(absl::StrCat(fdef.signature().input_arg(i).name(), ":0"));
754 }
755 // TODO(srbs): Private attrs on the op are dropped here and applied to
756 // the call op instead. If this causes problems we might have to copy those
757 // attrs to this ndef. That would require updating fname to contain a hash
758 // of such attributes.
759 for (const auto& attr : opdef.attr()) {
760 (*ndef->mutable_attr())[attr.name()].set_placeholder(attr.name());
761 }
762
763 #ifdef INTEL_MKL
764 if (IsMklEnabled() &&
765 absl::StartsWith(op->Name(), mkl_op_registry::kMklOpPrefix)) {
766 // All MKL eager ops have `_kernel` private attribute that needs to be set
767 // to a fixed label.
768 AttrValue attr_kernel;
769 attr_kernel.set_s(mkl_op_registry::kMklNameChangeOpLabel);
770 (*ndef->mutable_attr()).insert({"_kernel", attr_kernel});
771 }
772 #endif // INTEL_MKL
773
774 // Set `ret` map.
775 TF_RETURN_IF_ERROR(
776 PopulateRetMap(&fdef, op_attrs, op, opdef, signature, ndef->name()));
777 VLOG(1) << fdef.DebugString();
778 TF_RETURN_IF_ERROR(op->EagerContext().AddFunctionDef(std::move(fdef)));
779 }
780 // Build the call op.
781 auto& ctx = op->EagerContext();
782 AbstractOperationPtr call_op(ctx.CreateOperation());
783 TF_RETURN_IF_ERROR(call_op->Reset(fname.c_str(), op->DeviceName().c_str()));
784 for (auto t : op->Inputs()) {
785 TF_RETURN_IF_ERROR(call_op->AddInput(t));
786 }
787 TF_RETURN_IF_ERROR(call_op->SetDeviceName(op->DeviceName().c_str()));
788 *wrapped_op = down_cast<EagerOperation*>(call_op.release());
789 // Attributes on the elementary eager operation are applied to the call op and
790 // to the NodeDef inside the FunctionDef. This allows us to have a single
791 // FunctionDef for different attribute values. When the function is
792 // instantiated, these attributes get forwarded to the NodeDef. This is done
793 // by setting the AttrValue.placeholder field for the NodeDef attrs.
794 (*wrapped_op)->AddAttrs(op_attrs);
795 return AddMixedTypeListAttrs(*wrapped_op, op_attrs, opdef);
796 }
797
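// Looks up a cached KernelAndDevice for `op` using a fingerprint of its
// attributes, device name, and, when the op runs as a function, its input
// devices and resource dtypes/shapes. On a cache miss it selects a device if
// needed, builds the kernel, and caches it when caching is allowed. Also
// validates and updates *num_retvals.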
Status GetOrCreateKernelAndDevice(
    EagerOperation* op, TensorHandle** retvals, int* num_retvals,
    core::RefCountPtr<KernelAndDevice>* out_kernel) {
801 EagerContext& ctx = op->EagerContext();
802 Device* device = absl::get<Device*>(op->Device());
803
804 Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
  // Include the soft placement policy in the cache key since the placement
  // strategy can change and thus affect which kernel is picked.
807 cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
808 // The launch-time rendezvous reuse setting is bundled with the kernel, so we
809 // need to include it in the cache key.
810 cache_key =
811 FingerprintCat128(cache_key, ctx.GetReuseRendezvousForFunctions());
812
813 std::vector<Device*> input_dev_ptrs;
814 absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
815 std::unordered_map<int, DtypeAndPartialTensorShape>
816 input_resource_variable_dtypes_and_shapes;
817 // We can eliminate some overhead by running simple functions using regular
818 // CallOp kernel. However, it is tricky to figure out which functions should
819 // be run using CallOp. Also, currently CallOp runs neither optimization
820 // passes (needed for TPU/XLA) nor grappler.
821 // Here are some cases where a function should be run in multi-device mode:
822 // - Function takes at least two resources on different devices.
823 // - Function takes a resource on deviceA and a body op explicitly placed
824 // on deviceB.
825 // - Function has a colocation constraint.
826 // - Function has an explicit device annotation (which might not be using
827 // full canonical device name) different from op_device. Note that false
828 // positives are ok.
829 // - Function has a node or a (node) attribute that can potentially make
830 // the function multi-device after a rewrite pass (e.g. various XLA/TPU
831 // special nodes and attributes)
832 if (op->is_function() || ctx.RunEagerOpAsFunction()) {
833 profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
834 profiler::TraceMeLevel::kInfo);
835 input_dev_ptrs.reserve(op->Inputs().size());
836 const absl::InlinedVector<TensorHandle*, 4>* inputs;
837 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
838 for (int i = 0, end = inputs->size(); i < end; i++) {
839 TensorHandle* input = (*inputs)[i];
840
841 // Get device for this input, and add it to 'cache_key'.
842 Device* input_device;
843 TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
844 input_dev_ptrs.push_back(input_device);
845 CompositeDevice* composite_device = nullptr;
846 if (ctx.FindCompositeDeviceFromName(input_device->name(),
847 &composite_device)
848 .ok()) {
849 composite_devices[input_device->name()] =
850 composite_device->underlying_devices();
851 }
852 cache_key =
853 FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
854
855 // If input is a ResourceHandle, get its resource handle dtypes and shapes
856 // and add them to 'cache_key'.
857 if (input->dtype == DT_RESOURCE) {
858 // We only care about data type and shape for resource variable inputs.
        // But we have no way to tell if an input is a resource variable (other
        // than looking it up in ResourceMgr, which is slow). So we just get
861 // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
862 // resource_dtypes_and_shapes is not empty, take the first element.
863 std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
864 TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
865 &resource_dtypes_and_shapes));
866 if (!resource_dtypes_and_shapes.empty()) {
867 const DtypeAndPartialTensorShape& dtype_and_shape =
868 resource_dtypes_and_shapes.at(0);
869 input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
870
871 // Add _Arg index, dtype and shape to "cache_key".
872 cache_key = FingerprintCat128(cache_key, i);
873 DataType dtype = dtype_and_shape.dtype;
874 cache_key = FingerprintCat128(cache_key, dtype);
875 AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key);
876 }
877 }
878 }
879 }
880
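  // Check the kernel cache first; only build (and possibly wrap) a new
  // KernelAndDevice on a cache miss.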
881 core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
882 AbstractOperationPtr wrapped_op_releaser;
883 if (kernel == nullptr) {
884 if (ctx.RunEagerOpAsFunction() && !op->is_function()) {
885 EagerOperation* wrapped_op = nullptr;
886 TF_RETURN_IF_ERROR(WrapInCallOp(op, &wrapped_op));
887 DCHECK(wrapped_op);
888 DCHECK(wrapped_op->is_function());
889 wrapped_op_releaser.reset(wrapped_op);
890 op = wrapped_op;
891 }
892 DVLOG(2) << "Creating new kernel for " << op->Name() << " on device "
893 << DeviceNameOrUnspecified(absl::get<Device*>(op->Device()));
894 bool run_function_with_flr = false;
895 bool function_outputs_on_op_device = false;
896 if (op->is_function()) {
897 bool compile_with_xla;
898 TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
899 if (compile_with_xla) {
900 // Note that it is not ideal, but currently correct, to set this
901 // attribute after computing the kernel cache key above.
902 // Note: If the attribute is already set to true, this is a noop.
903 op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
904 } else {
905 run_function_with_flr = true;
906 }
907 GetFuncAttr(op, ctx, kOutputsOnOpDevice, &function_outputs_on_op_device)
908 .IgnoreError();
909 }
910
911 const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
912 if (device == nullptr) {
913 // Here in local execute, set preferred device to be on the local task to
914 // avoid placing op on a remote device with higher priority.
915 const DeviceNameUtils::ParsedName& preferred_device =
916 DeviceNameUtils::HasSomeDetails(op->GetDeviceParsedName())
917 ? op->GetDeviceParsedName()
918 : DeviceNameUtils::AddressSpace(ctx.HostCPUParsedName());
919 TF_RETURN_IF_ERROR(ctx.SelectDevice(preferred_device, ndef, &device));
920
921 DVLOG(1) << "Placer place op [" << op->Name()
922 << "] on device: " << device->name();
923 DVLOG(4) << "Available kernels for " << op->Name() << " are"
924 << KernelsRegisteredForOp(op->Name());
925 op->SetDevice(device);
926 }
927
928 FunctionLibraryRuntime* flr =
929 device == nullptr ? nullptr : ctx.func_lib(device);
930 if (device != nullptr && flr == nullptr) {
931 return errors::NotFound(
932 "Unable to find a FunctionLibraryRuntime corresponding to device ",
933 device->name());
934 }
935 auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
936 : ctx.runner();
937 GraphCollector* graph_collector = nullptr;
938 if (ctx.ShouldStoreGraphs()) {
939 graph_collector = ctx.GetGraphCollector();
940 }
941 // Treat the function as multi_device only when we are not compiling
942 // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
943 // will create an XlaLaunchOp kernel to compile and run the function.
944 if (run_function_with_flr) {
945 // Multi-device functions don't use the rendezvous from eager context.
946 // If we use that rendezvous, multiple concurrent calls to the same
947 // function will likely result in collisions. However, this also means
948 // that we don't support legitimate sending/receiving across function
949 // boundary.
950 DVLOG(2) << "Running " << ndef.op() << " using multi-device function. "
951 << "Full node_def=" << ndef.DebugString();
952 std::function<int64()> get_op_id = nullptr;
953 #if !defined(IS_MOBILE_PLATFORM)
954 get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
955 #endif // IS_MOBILE_PLATFORM
956 kernel.reset(new KernelAndDeviceFunc(
957 flr, ctx.pflr(), std::move(input_dev_ptrs),
958 std::move(composite_devices),
959 std::move(input_resource_variable_dtypes_and_shapes), runner,
960 ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
961 function_outputs_on_op_device, ctx.RendezvousCreator(), get_op_id));
962 } else {
963 DVLOG(2) << "Running " << ndef.op() << " using op kernel. "
964 << ". Full node_def=" << ndef.DebugString();
965 kernel.reset(new KernelAndDeviceOp(
966 ctx.GetRendezvous(), ctx.LogMemory(), flr, runner,
967 ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
968 }
969
970 TF_RETURN_IF_ERROR(
971 kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector));
972
973 if (op->is_function()) {
974 ctx.AddKernelToCache(cache_key, kernel.get());
975 } else {
976 // Exclude tf.data op kernels from being cached. The reason for this is
977 // that tf.data op kernels that accept a user-defined function will have a
978 // unique cache key every time they are executed (because the user-defined
979 // function is traced every time). Caching such kernels provides no
      // benefit and in some cases results in linear memory growth of user
      // programs that build input pipeline graphs in a loop.
982 const OpDef* op_def;
983 TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def));
984 if (KernelCacheEnabled(*op_def)) {
985 ctx.AddKernelToCache(cache_key, kernel.get());
986 }
987 }
988 }
989
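  // Make sure the caller provided enough space for the kernel's outputs
  // before handing back the kernel.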
990 int num_outputs = kernel->num_outputs();
991 if (num_outputs > *num_retvals) {
992 return errors::InvalidArgument("Expecting ", num_outputs,
993 " outputs, but *num_retvals is ",
994 *num_retvals);
995 }
996 *num_retvals = num_outputs;
997
998 kernel->Ref(); // Ownership of reference is passed to out_kernel.
999 out_kernel->reset(kernel.get());
1000 return Status::OK();
1001 }
1002
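// Creates a remote TensorHandle for output `output_num` of a remote op whose
// shape is not yet known, using the op id carried in `remote_func_params`.
// Not supported on mobile platforms.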
Status CreateUnshapedOutput(
    const KernelAndDevice& kernel, const int output_num, Device* output_device,
    const DataType& output_dtype,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    EagerContext* ctx, TensorHandle** output) {
1008 #if defined(IS_MOBILE_PLATFORM)
1009 return errors::Unimplemented(
1010 "Remote outputs are not available on mobile devices.");
1011 #else // !IS_MOBILE_PLATFORM
1012 int64_t op_id;
1013 if (remote_func_params.has_value()) {
1014 op_id = remote_func_params.value().op_id;
1015 } else {
1016 return errors::InvalidArgument(
1017 "Unable to find a remote op id for a remote output of ", kernel.name());
1018 }
1019 string remote_task;
1020 if (!DeviceNameUtils::GetTaskName(output_device->parsed_name(),
1021 &remote_task)) {
1022 return errors::InvalidArgument(
1023 "Unable to find remote task corresponding to device ",
1024 output_device->name());
1025 }
1026 if (ctx->RemoteMgr()->IsMaster()) {
1027 *output = TensorHandle::CreateUnshapedRemoteHandle(
1028 op_id, output_num, remote_task, output_dtype, output_device, ctx);
1029 } else {
1030 *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num,
1031 output_dtype, output_device,
1032 /*is_ready=*/false, ctx);
1033 }
1034 return Status::OK();
1035 #endif // !IS_MOBILE_PLATFORM
1036 }
1037
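// Runs the kernel for `op`, either by enqueuing an AsyncExecuteNode on the
// op's executor (async mode, with outputs created up front as empty or
// unshaped handles) or by executing an ExecuteNode synchronously.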
Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
                        EagerOperation* op, TensorHandle** retvals) {
1040 EagerExecutor& executor = op->Executor();
1041 EagerContext& ctx = op->EagerContext();
1042 GraphCollector* graph_collector = nullptr;
1043 if (ctx.ShouldStoreGraphs()) {
1044 graph_collector = ctx.GetGraphCollector();
1045 }
1046 const int num_outputs = kernel->num_outputs();
1047 absl::optional<EagerRemoteFunctionParams> remote_func_params =
1048 op->remote_func_params();
1049 if (kernel->IsCrossProcess() && !remote_func_params.has_value()) {
    // Create an eager op id for a cross-process function if one does not
    // exist.
1051 #if defined(IS_MOBILE_PLATFORM)
1052 return errors::Unimplemented(
1053 "Cross-process functions are not supported on mobile devices.");
1054 #else // !IS_MOBILE_PLATFORM
1055 const int64_t op_id = ctx.RemoteMgr()->NextOpId();
1056 remote_func_params =
1057 EagerRemoteFunctionParams{op_id, /*step_id=*/absl::nullopt};
1058 #endif // !IS_MOBILE_PLATFORM
1059 }
1060 if (executor.Async()) {
1061 const DataTypeVector& output_dtypes = kernel->output_dtypes();
1062 for (int i = 0, end = num_outputs; i < end; ++i) {
1063 Device* output_device = ctx.CanonicalDevice(kernel->OutputDevice(i));
1064 if (output_device == nullptr || output_device->IsLocal()) {
1065 retvals[i] = TensorHandle::CreateEmptyLocalHandle(
1066 /* d= */ output_device, /* op_device= */ kernel->device(),
1067 /* resource_device= */ kernel->OutputResourceDevice(i),
1068 output_dtypes[i], &ctx);
1069 } else {
1070 TF_RETURN_IF_ERROR(
1071 CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1072 remote_func_params, &ctx, &retvals[i]));
1073 }
1074 }
1075 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1076 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1077 auto node = absl::make_unique<AsyncExecuteNode>(
1078 &ctx, *inputs, remote_func_params, std::move(kernel), graph_collector,
1079 op->GetCancellationManager(),
1080 absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
1081 // Release the inputs from the eager operation since the AsyncExecuteNode
1082 // would have taken ownership. This allows the inputs to be forwarded if
1083 // possible.
1084 op->Clear();
1085 // For async mode, execution order will make sure that all
1086 // input handles are ready before executing them.
1087 // TODO(b/137118203): Consider executing "cheap" kernels inline for
1088 // performance.
1089 return executor.AddOrExecute(std::move(node));
1090 } else {
1091 for (int i = 0, end = num_outputs; i < end; ++i) {
1092 retvals[i] = nullptr;
1093 }
1094 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1095 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1096 ExecuteNode node(&ctx, *inputs, remote_func_params, kernel, graph_collector,
1097 op->GetCancellationManager(),
1098 {retvals, static_cast<size_t>(num_outputs)},
1099 op->GetStackTrace());
1100 Status s = executor.SyncExecute(&node);
1101 // We release the inputs AFTER executing the operation in sync mode since
1102 // ExecuteNode does not increment the reference count and thus does not have
1103 // ownership of the inputs while executing.
1104 op->Clear();
1105 return s;
1106 }
1107 }
1108
1109 // There are a lot of references to devices in this function and around.
1110 // Here is what they mean:
1111 // EagerOperation::Device(): The device on which the user requested the op
1112 // be executed, except if we had to change the device due to resource inputs
1113 // or CPU pinning. If the user did not request a device, the op does not
1114 // take resources, and we did not pin it to CPU, the device can be nullptr.
1115 // KernelAndDevice::Device(): The first time we see an op (combined with
1116 // its attributes), we need to create a KernelAndDevice object for it.
1117 // If op->Device() is a nullptr, we select a device for the op when
1118 // creating the KernelAndDevice. A concrete device will always be selected
1119 // here except when `op` is a function to be executed using function library
1120 // runtime. In this case, we don't select a device because running
1121 // a function with explicitly requested device has different behavior than
1122 // running without an explicitly requested device.
Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
                         int* num_retvals) {
1125 ScopedMemoryDebugAnnotation op_annotation(
1126 op->op_name(), op->remote_func_params().has_value()
1127 ? op->remote_func_params().value().step_id.value_or(0)
1128 : 0);
1129 profiler::TraceMe activity(
1130 [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
1131 profiler::TraceMeLevel::kInfo);
1132 EagerContext& ctx = op->EagerContext();
1133 auto& executor = op->Executor();
1134 TF_RETURN_IF_ERROR(executor.status());
1135
1136 core::RefCountPtr<KernelAndDevice> kernel;
1137 auto status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1138
  // Run all the registered rewrite passes after placement, regardless of
  // whether the placement was successful. The passes can either create new
  // ops (without placement) or update some fields of the input op.
1142 std::unique_ptr<tensorflow::EagerOperation> out_op;
1143 TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1144 EagerOpRewriteRegistry::POST_PLACEMENT, op, &out_op));
1145 if (out_op) {
1146 op = out_op.get();
    // If the out op doesn't have a device, either because it is a new op or
    // the op wasn't placed successfully, then we do the placement again.
1149 if (op->Device() == kVariantDeviceNull) {
1150 status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1151 }
1152 }
1153 if (!status.ok()) return status;
1154
1155 int num_outputs = kernel->num_outputs();
1156 TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
1157
1158 if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
1159 string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
1160 kernel->device()->name());
1161 if (!logging::LogToListeners(msg)) {
1162 LOG(INFO) << msg;
1163 }
1164 }
1165
1166 Status s = AddOrExecuteNode(std::move(kernel), op, retvals);
  // If the operation failed, we need to Unref any outputs that were
  // allocated.
1169 if (!s.ok()) {
1170 for (int i = 0, end = num_outputs; i < end; ++i) {
1171 if (retvals[i] != nullptr) {
1172 retvals[i]->Unref();
1173 }
1174 }
1175 }
1176
1177 return s;
1178 }
1179
1180 // Run a Pack op to pack the tensors pointed by a packed input TensorHandle if
1181 // the op is a primitive op.
Status MaybePackInputTensor(EagerOperation* op) {
1183 if (op->is_function()) {
1184 // Functions could take packed TensorHandles as inputs.
1185 return Status::OK();
1186 }
1187 EagerContext& ctx = op->EagerContext();
1188 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1189 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1190 for (int i = 0; i < inputs->size(); ++i) {
1191 TensorHandle* handle = (*inputs)[i];
1192 if (handle->Type() == TensorHandle::PACKED) {
1193 EagerOperation pack_op(&ctx);
1194 TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr,
1195 /*remote=*/false, /*executor=*/nullptr));
1196 pack_op.MutableAttrs()->Set("N", handle->NumPackedHandles());
1197 pack_op.MutableAttrs()->Set("T", handle->dtype);
      for (int j = 0; j < handle->NumPackedHandles(); ++j) {
        tensorflow::TensorHandle* h = nullptr;
        TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h));
1201 TF_RETURN_IF_ERROR(pack_op.AddInput(h));
1202 }
1203 int num_retvals = 1;
1204 absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
1205 TF_RETURN_IF_ERROR(
1206 EagerLocalExecute(&pack_op, retvals.data(), &num_retvals));
1207 tensorflow::TensorHandle* ret = retvals.at(0);
1208 op->UpdateInput(i, ret);
1209 ret->Unref();
1210 }
1211 }
1212 return Status::OK();
1213 }
1214
1215 #if !defined(IS_MOBILE_PLATFORM)
void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
1217 EagerContext& ctx = op->EagerContext();
1218
1219 remote_op->set_id(ctx.RemoteMgr()->NextOpId());
1220 remote_op->set_name(op->Name());
1221
1222 op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
1223 remote_op->set_device(absl::get<Device*>(op->Device())->name());
1224 remote_op->set_is_function(op->is_function());
1225 }
1226
Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
                                    const DataTypeVector& output_dtypes,
                                    TensorHandle** retvals) {
1230 if (remote_op.name() == "VarHandleOp") {
1231 if (output_dtypes.size() != 1) {
1232 return errors::Internal("VarHandleOp should only have one output.");
1233 }
1234 if (output_dtypes[0] != DT_RESOURCE) {
1235 return errors::Internal(
1236 "The output of VarHandleOp should be a DT_RESOURCE.");
1237 }
1238 AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
1239 const AttrValue* dtype;
1240 TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
1241 const AttrValue* shape;
1242 TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
1243 retvals[0]->SetResourceHandleDtypeAndShape(
1244 {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
1245 }
1246 return Status::OK();
1247 }
1248
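// Executes `op` on a remote task: serializes its inputs as remote tensor
// handles (copying local inputs to the remote task's CPU when needed),
// creates unshaped remote handles for the outputs, and enqueues a
// RemoteExecuteNode on the op's executor.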
Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
                          int* num_retvals) {
1251 EagerContext& ctx = op->EagerContext();
1252
1253 // TODO(fishx): Remove following code when lazy tensor copy is ready.
1254 if (op->Device() == kVariantDeviceNull) {
1255 tensorflow::Device* device = nullptr;
1256 string device_name = op->DeviceName();
1257 TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
1258 op->SetDevice(device);
1259 }
1260
1261 core::RefCountPtr<eager::EagerClient> eager_client;
1262 uint64 context_id = ctx.GetContextId();
1263 TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client));
1264 string remote_task;
1265 if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
1266 return errors::InvalidArgument(
1267 "Unable to find remote task corresponding to device ",
1268 op->DeviceName());
1269 }
1270
1271 std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
1272 request->set_context_id(context_id);
1273
1274 eager::Operation* remote_op = request->add_queue()->mutable_operation();
1275
1276 tensorflow::Device* op_device = absl::get<Device*>(op->Device());
1277 {
1278 profiler::TraceMe activity("CopyInputToExpectedDevice",
1279 profiler::TraceMeLevel::kInfo);
1280 const bool is_function = op->is_function();
1281 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1282 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1283 for (int i = 0, end = inputs->size(); i < end; i++) {
1284 tensorflow::TensorHandle* input = (*inputs)[i];
1285 tensorflow::Device* input_device = input->device();
1286 tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx);
1287 const string* input_device_name = &input_device_or_cpu->name();
1288 bool serialize_resource_dtype_and_shape = false;
1289 if (op_device != input_device &&
1290 // If the expected and actual devices are on the same task, don't
1291 // explicitly copy, and instead depend on the copy to happen locally
1292 // when the op is executed on the device.
1293 !ctx.OnSameTask(op_device, input_device)) {
1294 if (!is_function || input_device_or_cpu->IsLocal()) {
1295 tensorflow::Device* remote_cpu_device;
1296 TF_RETURN_IF_ERROR(
1297 ctx.CPUDeviceOnTask(op_device, &remote_cpu_device));
1298 // Always copy to the remote CPU so that the actual device can be
1299 // correctly determined after the kernel is selected/instantiated,
1300 // since the op might have its inputs on host memory.
1301 TensorHandle* handle = input;
1302 Device* handle_device = handle->DeviceOrHostCPU(ctx);
1303 // If the input is already on the right device, then nothing to do.
1304 if (remote_cpu_device != handle_device) {
1305 TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
1306 &ctx, op, op_device, handle, i, handle_device,
1307 remote_cpu_device, &handle));
1308 op->UpdateInput(i, handle);
1309 input = handle;
1310 input_device = remote_cpu_device;
1311 input_device_name = &remote_cpu_device->name();
1312 // Unref handle since it has a ref as an input now
1313 handle->Unref();
1314 }
1315 } else {
1316 serialize_resource_dtype_and_shape =
1317 (input->dtype == DT_RESOURCE) &&
1318 (!input->HasResourceShapeMirror(op_device,
1319 ctx.GetContextViewId()));
1320 }
1321 }
1322 auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
1323 // For a remote component function, a function execution request and an
1324 // input generation request may come from different workers. We need to
1325 // guarantee that the input generation request is processed before the
1326 // function execution request, so wait until the remote input is ready
1327 // before sending it to the multi-device function device.
1328 const bool wait_until_ready = op->is_function();
1329 TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
1330 input, wait_until_ready, input_handle, input_device,
1331 *input_device_name, serialize_resource_dtype_and_shape));
1332 if (!input_handle->resource_dtypes_and_shapes().empty()) {
1333 TF_RETURN_IF_ERROR(
1334 input->AddResourceShapeMirror(op_device, input_handle->op_id(),
1335 input_handle->output_num(), &ctx));
1336 }
1337 }
1338 }
1339
1340 PrepareRemoteOp(remote_op, op);
1341
1342 DataTypeVector output_dtypes;
1343 TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
1344
1345 const size_t num_outputs = output_dtypes.size();
1346 if (num_outputs != *num_retvals) {
1347 return errors::InvalidArgument(
1348 "num_retvals does not match expected output dtypes");
1349 }
1350 *num_retvals = num_outputs;
1351
1352 const tensorflow::uint64 id = remote_op->id();
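// Each output is represented on this client by an unshaped remote handle,
// identified by the remote op id plus the output index; the shape is filled
// in later, once the remote worker reports it.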
1353 for (size_t i = 0; i < num_outputs; ++i) {
1354 // TODO(nareshmodi): Change the callback to instead add the decref to a
1355 // list of pending decrefs that we can send as a batch with the next
1356 // execute.
1357
1358 // The device_ and resource_device_ of this TensorHandle might be
1359 // incorrect. For multi-device functions, we don't know the output device
1360 // until the function is instantiated on a remote worker. Luckily, we don't
1361 // need to know the correct remote device here. We just need to know that it
1362 // is remote. If we need to copy this tensor to this process or run any ops
1363 // which take this tensor as an input, block until the correct device is
1364 // set.
1365 const bool unknown_device = op->is_function();
1366 retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
1367 id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
1368 }
1369
1370 // Store the data type and shape of a remote resource variable on the
1371 // corresponding remote TensorHandle (output of 'VarHandleOp').
1372 // If the variable is an input of a remote function, the function may need
1373 // the type and shape during function instantiation. Store the type and
1374 // shape on the eager master and send them to the default function device along
1375 // with the EnqueueRequest.
1376 TF_RETURN_IF_ERROR(
1377 StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
1378
1379 auto& executor = op->Executor();
1380 DVLOG(4) << "Execute remote eager op: " << op->Name()
1381 << " (is async?: " << executor.Async() << ").";
1382
1383 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1384 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1385
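// The RemoteExecuteNode takes ownership of the request and, when run, sends
// it to the remote eager client; in async mode AddOrExecute below queues the
// node, otherwise it runs inline.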
1386 std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
1387 &op->EagerContext(), std::move(request), op_device,
1388 ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(),
1389 op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
1390 *inputs, {retvals, num_outputs}));
1391
1392 if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
1393 string msg = strings::StrCat(
1394 "Executing op ", op->Name(), " on task ",
1395 DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
1396 if (!logging::LogToListeners(msg)) {
1397 LOG(INFO) << msg;
1398 }
1399 }
1400
1401 Status s = executor.AddOrExecute(std::move(node));
1402 // If the operation failed, we need to Unref any outputs that were
1403 // allocated.
1404 if (!s.ok()) {
1405 for (size_t i = 0; i < num_outputs; ++i) {
1406 retvals[i]->Unref();
1407 }
1408 }
1409
1410 return s;
1411 }
1412 #endif // IS_MOBILE_PLATFORM
1413
1414 Status GetKernelOutputs(
1415 std::vector<EagerKernelRet>* outputs, int num_outputs,
1416 TensorHandle** retvals, EagerContext* ctx, KernelAndDevice* kernel,
1417 const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
1418 for (int i = 0, end = num_outputs; i < end; ++i) {
1419 if (retvals[i] == nullptr) {
1420 EagerKernelRet& ret = (*outputs)[i];
1421 Device* output_device = ctx->CanonicalDevice(kernel->OutputDevice(i));
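// An EagerKernelRet either holds a concrete Tensor produced locally
// (index 0) or only a TensorShape for an output that stays on a remote
// device.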
1422 if (ret.index() == 0) {
1423 retvals[i] = TensorHandle::CreateLocalHandle(
1424 std::move(absl::get<Tensor>(ret)),
1425 /* d= */ output_device,
1426 /* op_device= */ kernel->device(),
1427 /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
1428 } else {
1429 const DataTypeVector& output_dtypes = kernel->output_dtypes();
1430 TF_RETURN_IF_ERROR(
1431 CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1432 remote_func_params, ctx, &retvals[i]));
1433 #if !defined(IS_MOBILE_PLATFORM)
1434 TF_RETURN_IF_ERROR(
1435 retvals[i]->SetRemoteShape(absl::get<TensorShape>(ret),
1436 output_device, ctx->GetContextViewId()));
1437 #endif // IS_MOBILE_PLATFORM
1438 }
1439 } else {
1440 if (!kernel->IsFunction() &&
1441 TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) {
1442 return errors::Internal(
1443 "Kernel output tensor handle has a different op device than the "
1444 "kernel. This should never happen.");
1445 }
1446 if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
1447 retvals[i]->device())) {
1448 return errors::Internal(
1449 "Kernel output tensor handle locates on a different device than "
1450 "the specified kernel output device. This should never happen.");
1451 }
1452
1453 EagerKernelRet& ret = (*outputs)[i];
1454 if (ret.index() == 0) {
1455 TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
1456 std::move(absl::get<Tensor>(ret)),
1457 ctx->CanonicalDevice(kernel->OutputDevice(i))));
1458 } else {
1459 #if defined(IS_MOBILE_PLATFORM)
1460 return errors::Unimplemented(
1461 "Remote outputs are not available on mobile devices.");
1462 #else // !IS_MOBILE_PLATFORM
1463 TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape(
1464 absl::get<TensorShape>(ret), retvals[i]->device(),
1465 ctx->GetContextViewId()));
1466 #endif // !IS_MOBILE_PLATFORM
1467 }
1468 }
1469 }
1470 return Status::OK();
1471 }
1472
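// Copies any graphs gathered by the context's GraphCollector (raw, optimized
// and partitioned function graphs) into the context's RunMetadata, then
// clears the collector.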
1473 void CollectGraphs(EagerContext* ctx) {
1474 mutex_lock ml(*ctx->MetadataMu());
1475
1476 GraphCollector* collector = ctx->GetGraphCollector();
1477 mutex_lock mll(collector->mu);
1478
1479 // Adding to partition graphs for backward compatibility.
1480 for (const auto& graph : collector->partitioned_graphs) {
1481 *ctx->RunMetadataProto()->add_partition_graphs() = graph;
1482 }
1483
1484 if (collector->dirty) {
1485 auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
1486 *function_graphs->mutable_post_optimization_graph() =
1487 collector->optimized_graph;
1488 *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph;
1489 for (const auto& graph : collector->partitioned_graphs) {
1490 *function_graphs->add_partition_graphs() = graph;
1491 }
1492 }
1493
1494 collector->ClearGraphs();
1495 }
1496 } // namespace
1497
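// Runs `op` either locally or remotely, after applying any registered
// PRE_EXECUTION rewrites. On success, `*num_retvals` output handles are
// returned in `retvals`.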
1498 Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
1499 int* num_retvals) {
1500 profiler::TraceMe activity(
1501 [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
1502 profiler::TraceMeLevel::kInfo);
1503
1504 if (!op->Executor().Async()) {
1505 // In sync mode, always clear error to maintain the same behavior as before.
1506 // TODO(b/141004939): Remove this.
1507 op->Executor().ClearError();
1508 }
1509
1510 std::unique_ptr<tensorflow::EagerOperation> out_op;
1511 TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1512 EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
1513
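// A rewrite may produce a replacement operation; when it does, execute the
// rewritten op in place of the original.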
1514 if (op->IsLocal()) {
1515 if (out_op) {
1516 op = out_op.get();
1517 }
1518 TF_RETURN_IF_ERROR(MaybePackInputTensor(op));
1519 return EagerLocalExecute(op, retvals, num_retvals);
1520 }
1521
1522 #if defined(IS_MOBILE_PLATFORM)
1523 return errors::Unimplemented(
1524 "Eager's remote execution is not available on mobile devices.");
1525 #else // !IS_MOBILE_PLATFORM
1526 if (out_op) {
1527 op = out_op.get();
1528 }
1529 return EagerRemoteExecute(op, retvals, num_retvals);
1530 #endif // !IS_MOBILE_PLATFORM
1531 }
1532
1533 // TODO(gjn): Consider moving into ExecuteNode class
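// Runs an already-instantiated kernel synchronously on its device with
// `op_inputs`, then materializes the results into `retvals` via
// GetKernelOutputs.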
1534 Status EagerKernelExecute(
1535 EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1536 const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1537 const core::RefCountPtr<KernelAndDevice>& kernel,
1538 GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1539 absl::Span<TensorHandle*> retvals,
1540 const absl::optional<ManagedStackTrace>& stack_trace) {
1541 profiler::TraceMe activity("EagerKernelExecute",
1542 profiler::TraceMeLevel::kInfo);
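// The initial size is only a placeholder; kernel->Run() is expected to leave
// one entry per kernel output (checked against `retvals` below).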
1543 std::vector<EagerKernelRet> outputs(1);
1544
1545 ExecuteNodeArgs inputs(op_inputs.size());
1546 TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs, kernel));
1547 // TODO(apassos) figure out how to record stats for ops which are a part of
1548 // functions.
1549 // TODO(b/111859745): When we support recovering from kernel/device errors, we
1550 // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA
1551 // device. We don't call it now because it is an unneeded overhead (it
1552 // acquires a lock) and we can't recover from errors anyway.
1553 ScopedStepContainer* container = ctx->StepContainer();
1554 CoordinationServiceAgent* coord_agent = nullptr;
1555 #if !defined(IS_MOBILE_PLATFORM)
1556 if (ctx->GetDistributedManager() != nullptr)
1557 coord_agent = ctx->GetDistributedManager()->GetCoordinationServiceAgent();
1558 #endif // !IS_MOBILE_PLATFORM
1559 TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
1560 cancellation_manager, remote_func_params,
1561 stack_trace, coord_agent));
1562 if (graph_collector != nullptr) {
1563 CollectGraphs(ctx);
1564 }
1565
1566 if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) {
1567 return errors::Internal(
1568 "EagerKernelExecute returns a list of ", outputs.size(),
1569 " tensors but ", retvals.size(),
1570 " is expected. This should never "
1571 "happen. Please file a bug with the TensorFlow team.");
1572 }
1573 return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx,
1574 kernel.get(), remote_func_params);
1575 }
1576
1577 namespace {
1578
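// Copies `h` to the local device `dstd`, or adds a local mirror on `dstd`
// when `mirror` is set. In async mode the copy is queued on `executor`;
// otherwise it runs synchronously.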
1579 Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1580 EagerExecutor* executor, Device* dstd,
1581 bool mirror, TensorHandle** result) {
1582 TF_RETURN_IF_ERROR(executor->status());
1583 Device* d = ctx->CanonicalDevice(dstd);
1584 if (mirror && h->HasLocalMirror(d)) {
1585 h->Ref();
1586 *result = h;
1587 return Status::OK();
1588 }
1589
1590 bool async = executor->Async();
1591 if (mirror) {
1592 h->Ref();
1593 *result = h;
1594
1595 if (h->HasLocalMirror(d)) {
1596 return Status::OK();
1597 }
1598
1599 // We don't bother adding an empty local mirror in sync mode since we'll be
1600 // executing the operation directly and calling AddLocalMirror ourselves. A
1601 // reference count is still needed; it will be removed if the operation
1602 // fails.
1603 if (async) {
1604 Status s = h->AddEmptyLocalMirror(d);
1605 if (!s.ok()) {
1606 // If a mirror was added after we called HasLocalMirror, just return:
1607 // another thread has already added the mirror.
1608 if (s.code() == error::Code::ALREADY_EXISTS) {
1609 return Status::OK();
1610 }
1611
1612 // Remove the previously added reference count since adding the mirror
1613 // failed.
1614 h->Unref();
1615 *result = nullptr;
1616 return s;
1617 }
1618 }
1619 } else {
1620 *result = TensorHandle::CreateEmptyLocalHandle(
1621 d, dstd, h->resource_device(), h->dtype, ctx);
1622 }
1623
1624 Status s;
1625 if (async) {
1626 // Note that `h` may not be currently ready. However execution order will
1627 // make sure that `h` is ready before the copy is actually done.
1628 std::unique_ptr<EagerNode> node(
1629 new CopyToDeviceNode(h, *result, d, *ctx, async, mirror));
1630 s = executor->AddOrExecute(std::move(node));
1631 } else {
1632 CopyToDeviceNode node(h, *result, d, *ctx, async, mirror);
1633 s = executor->SyncExecute(&node);
1634 }
1635
1636 // If the operation failed, we need to Unref any outputs that were
1637 // allocated.
1638 if (!s.ok()) {
1639 (*result)->Unref();
1640 }
1641
1642 return s;
1643 }
1644
1645 } // namespace
1646
1647 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1648 EagerExecutor* executor, Device* device, bool mirror,
1649 TensorHandle** result) {
1650 TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
1651 auto send_device = h->DeviceOrHostCPU(*ctx);
1652 bool sender_is_local = send_device->IsLocal();
1653
1654 bool receiver_is_local = device->IsLocal();
1655
1656 if (!executor->Async()) {
1657 // In sync mode, always clear error to maintain the same behavior as before.
1658 // TODO(b/141004939): Remove this.
1659 executor->ClearError();
1660 }
1661
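// Local-to-local copies are handled in-process; any copy that involves a
// remote device goes through a RemoteCopyNode (not available on mobile).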
1662 if (sender_is_local && receiver_is_local) {
1663 return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result);
1664 } else {
1665 #if defined(IS_MOBILE_PLATFORM)
1666 return errors::Unimplemented(
1667 "Eager's remote execution is not available on mobile devices.");
1668 #else // !IS_MOBILE_PLATFORM
1669 uint64 recv_op_id = 0;
1670 if (receiver_is_local) {
1671 Device* d = ctx->CanonicalDevice(device);
1672 // TODO(gjn): Need to add support for async execution. Note if receiver
1673 // is local, we need to first add support in TensorHandle to wait on local
1674 // mirrors.
1675 if (mirror) {
1676 h->Ref();
1677 *result = h;
1678
1679 if (h->HasLocalMirror(d)) {
1680 return Status::OK();
1681 }
1682
1683 Status s = h->AddEmptyLocalMirror(d);
1684 if (!s.ok()) {
1685 // If a mirror was added after we called HasLocalMirror, just
1686 // return: another thread has already added the mirror.
1687 if (s.code() == error::Code::ALREADY_EXISTS) {
1688 return Status::OK();
1689 }
1690
1691 // Remove the previously added reference count since adding the mirror
1692 // failed.
1693 h->Unref();
1694 *result = nullptr;
1695 return s;
1696 }
1697 } else {
1698 *result = TensorHandle::CreateEmptyLocalHandle(
1699 /* d= */ d, /* op_device= */ device,
1700 /*resource_device=*/nullptr, h->dtype, ctx);
1701 }
1702 } else {
1703 if (mirror) {
1704 if (h->HasRemoteMirror(device, ctx->GetContextViewId())) {
1705 h->Ref();
1706 *result = h;
1707 return Status::OK();
1708 }
1709 }
1710 string remote_task;
1711 if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
1712 return errors::InvalidArgument(
1713 "Unable to find remote task corresponding to device ",
1714 device->name());
1715 }
1716 recv_op_id = ctx->RemoteMgr()->NextOpId();
1717 if (mirror) {
1718 TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
1719 remote_task, ctx));
1720 h->Ref();
1721 *result = h;
1722 } else {
1723 *result = TensorHandle::CreateUnshapedRemoteHandle(
1724 recv_op_id, 0, remote_task, h->dtype, device, ctx);
1725 }
1726 }
1727
1728 auto node = std::make_unique<eager::RemoteCopyNode>(
1729 ctx, executor, h, result[0], device, recv_op_id);
1730 Status s = executor->AddOrExecute(std::move(node));
1731 if (!s.ok()) {
1732 result[0]->Unref();
1733 }
1734 return s;
1735 #endif // !IS_MOBILE_PLATFORM
1736 }
1737 }
1738
1739 namespace {
1740 // Low-level utility function to execute the kernel specified by `kernel` on
1741 // `kernel->device()`, with the provided inputs as `op_inputs` within `ctx`.
1742 // Unlike `EagerKernelExecute`, which ties up the thread until the underlying
1743 // function finishes executing, this function does not block the thread and
1744 // may return before the function execution finishes. The provided
1745 // `StatusCallback` will be triggered after function execution with its status.
1746 void EagerKernelExecuteAsync(
1747 EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1748 const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1749 const core::RefCountPtr<KernelAndDevice> kernel,
1750 GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1751 TensorHandle** retvals, int num_outputs, StatusCallback done) {
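// `inputs` and `outputs` are shared_ptrs because they are captured by the
// RunAsync callback below and must stay alive until it fires.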
1752 auto inputs = std::make_shared<ExecuteNodeArgs>(op_inputs.size());
1753 auto outputs = std::make_shared<std::vector<EagerKernelRet>>(1);
1754
1755 Status s = inputs->Init(ctx, op_inputs, kernel);
1756 if (!s.ok()) {
1757 done(s);
1758 return;
1759 }
1760 CoordinationServiceAgent* coord_agent = nullptr;
1761 #if !defined(IS_MOBILE_PLATFORM)
1762 if (ctx->GetDistributedManager() != nullptr)
1763 coord_agent = ctx->GetDistributedManager()->GetCoordinationServiceAgent();
1764 #endif // !IS_MOBILE_PLATFORM
1765
1766 kernel->Ref(); // Ownership of reference is transferred to the callback
1767 kernel->RunAsync(
1768 ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager,
1769 remote_func_params, coord_agent,
1770 [retvals, inputs, outputs, num_outputs, ctx, graph_collector,
1771 remote_func_params, kernel_raw = kernel.get(),
1772 done = std::move(done)](const Status& s) {
1773 auto wrapped_done = [&](const Status& s) {
1774 kernel_raw->Unref();
1775 done(s);
1776 };
1777 if (!s.ok()) {
1778 wrapped_done(s);
1779 return;
1780 }
1781 if (graph_collector != nullptr) {
1782 CollectGraphs(ctx);
1783 }
1784 DCHECK_EQ(num_outputs, outputs->size());
1785 wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx,
1786 kernel_raw, remote_func_params));
1787 });
1788 }
1789 } // namespace
1790
1791 // Low-level utility to run the eager operation on local devices. Unlike
1792 // `EagerLocalExecute`, which blocks and waits for the op execution to
1793 // finish, this method does not block the thread and may return before the
1794 // eager operation execution finishes. The provided `StatusCallback` will be
1795 // triggered after execution with its status.
1796 void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
1797 int* num_retvals, StatusCallback done) {
1798 if (!op->IsLocal()) {
1799 done(errors::InvalidArgument(
1800 "Remote execution is not supported in async EagerLocalExecuteAsync"));
1801 return;
1802 }
1803
1804 ScopedMemoryDebugAnnotation op_annotation(
1805 op->op_name(), op->remote_func_params().has_value()
1806 ? op->remote_func_params().value().step_id.value_or(0)
1807 : 0);
1808 profiler::TraceMe activity(
1809 [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); },
1810 profiler::TraceMeLevel::kInfo);
1811 EagerContext& ctx = op->EagerContext();
1812
1813 core::RefCountPtr<KernelAndDevice> kernel;
1814 Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1815 if (!s.ok()) {
1816 done(s);
1817 return;
1818 }
1819
1820 int num_outputs = kernel->num_outputs();
1821 s = ValidateInputTypeAndPlacement(&ctx, op, kernel);
1822 if (!s.ok()) {
1823 done(s);
1824 return;
1825 }
1826
1827 if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
1828 string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
1829 kernel->device()->name());
1830 if (!logging::LogToListeners(msg)) {
1831 LOG(INFO) << msg;
1832 }
1833 }
1834
1835 GraphCollector* graph_collector = nullptr;
1836 if (ctx.ShouldStoreGraphs()) {
1837 graph_collector = ctx.GetGraphCollector();
1838 }
1839
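// Pre-allocate empty output handles so callers can hold references to the
// outputs immediately; the handles are populated by GetKernelOutputs in the
// async callback.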
1840 for (int i = 0, end = num_outputs; i < end; ++i) {
1841 const DataTypeVector& output_dtypes = kernel->output_dtypes();
1842 retvals[i] = TensorHandle::CreateEmptyLocalHandle(
1843 /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
1844 /* op_device= */ kernel->device(),
1845 /* resource_device= */ kernel->OutputResourceDevice(i),
1846 output_dtypes[i], &ctx);
1847 }
1848
1849 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1850 s = op->TensorHandleInputs(&inputs);
1851 if (!s.ok()) {
1852 done(s);
1853 return;
1854 }
1855 EagerKernelExecuteAsync(
1856 &ctx, *inputs, op->remote_func_params(), std::move(kernel),
1857 graph_collector, op->GetCancellationManager(), retvals, num_outputs,
1858 [op, num_outputs, retvals, done = std::move(done)](const Status& s) {
1859 op->Clear();
1860 // If the operation failed, we need to Unref any outputs that were
1861 // allocated.
1862 if (!s.ok()) {
1863 for (int i = 0, end = num_outputs; i < end; ++i) {
1864 if (retvals[i] != nullptr) {
1865 retvals[i]->Unref();
1866 }
1867 }
1868 }
1869 done(s);
1870 });
1871 }
1872 } // namespace tensorflow
1873