1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/execute.h"
17
18 #include <cstddef>
19 #include <vector>
20
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/lib/core/refcount.h"
29 #include "tensorflow/core/platform/platform.h"
30 // clang-format on
31
32 #include "absl/container/inlined_vector.h"
33 #include "absl/strings/match.h"
34 #include "absl/strings/str_cat.h"
35 #include "absl/types/optional.h"
36 #include "tensorflow/c/tf_tensor_internal.h"
37 #include "tensorflow/compiler/jit/defs.h"
38 #include "tensorflow/core/common_runtime/device.h"
39 #include "tensorflow/core/common_runtime/device_set.h"
40 #include "tensorflow/core/common_runtime/eager/context.h"
41 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
42 #include "tensorflow/core/common_runtime/eager/execute_node.h"
43 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
44 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
45 #include "tensorflow/core/common_runtime/colocation_graph.h"
46 #include "tensorflow/core/framework/dataset.h"
47 #include "tensorflow/core/framework/function.h"
48 #include "tensorflow/core/framework/logging.h"
49 #include "tensorflow/core/framework/node_def_util.h"
50 #include "tensorflow/core/framework/tensor_reference.h"
51 #include "tensorflow/core/framework/types.pb.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/profiler/lib/traceme.h"
54 #include "tensorflow/core/protobuf/error_codes.pb.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56 #if !defined(IS_MOBILE_PLATFORM)
57 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
58 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
59 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
60 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
61 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
62 #endif // IS_MOBILE_PLATFORM
63 #include "tensorflow/core/framework/step_stats.pb.h"
64 #include "tensorflow/core/framework/tensor.h"
65 #include "tensorflow/core/framework/types.h"
66 #include "tensorflow/core/lib/core/status.h"
67 #include "tensorflow/core/lib/gtl/cleanup.h"
68 #include "tensorflow/core/lib/gtl/flatset.h"
69 #include "tensorflow/core/lib/random/random.h"
70 #include "tensorflow/core/platform/env.h"
71 #include "tensorflow/core/platform/mutex.h"
72 #include "tensorflow/core/util/ptr_util.h"
73 #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
74
75 namespace tensorflow {
76
77 namespace {
78
79 const string& DeviceNameOrUnspecified(Device* device) {
80 static string* unspecified_string = new string("<unspecified>");
81 return (device == nullptr) ? *unspecified_string : device->name();
82 }
83
84 // Returns whether a kernel should be cached.
85 bool KernelCacheEnabled(const OpDef& op_def) {
86 if (data::DatasetOpKernel::IsDatasetOp(&op_def)) {
87 return false;
88 }
89 // TODO(b/162540360): Revisit a way to mark kernels as uncachable once we have
90 // 5+ kernels to exclude.
91 return true;
92 }
93
94 // This function expects *handle to point to an existing tensor handle that is
95 // currently on "handle_device", but where the operation expects that input to
96 // reside on "expected_input_device". The function will arrange for this
97 // transfer to happen and will return OK on success and will store a new
98 // handle to the equivalent tensor on the correct device in "*result". If an
99 // error is encountered, it will return a non-OK status and set "*result" to
100 // nullptr.
101 //
102 // `op_device` is passed in explicitly because `op->device()` might be
103 // unset and we might have selected some specific device to run this op on.
104 Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op,
105 Device* op_device,
106 TensorHandle* handle, // op->Inputs()[i]
107 int i, Device* handle_device,
108 Device* expected_input_device,
109 TensorHandle** result) {
110 // Should only be called when these don't match
111 DCHECK(expected_input_device != handle_device);
112 *result = nullptr;
113 const string& op_device_name = DeviceNameOrUnspecified(op_device);
114
115 switch (ctx->GetDevicePlacementPolicy()) {
116 case DEVICE_PLACEMENT_SILENT_FOR_INT32:
117 // TODO(xpan): See if we could bubble python related error up
118 // to python level.
119 if (handle->dtype == DT_INT32) {
120 // Note: enabling silent copies of int32 tensors to match behavior
121 // of graph mode.
122 break;
123 }
124 TF_FALLTHROUGH_INTENDED;
125 case DEVICE_PLACEMENT_EXPLICIT:
126 // tf.identity is allowed to copy, as indicated in the error message
127 // below.
128 if (op->Name() == "Identity" || op->Name() == "IdentityN") {
129 break;
130 }
131 return errors::InvalidArgument(
132 "Tensors on conflicting devices:"
133 " cannot compute ",
134 op->Name(), " as input #", i, " was expected to be on ",
135 expected_input_device->name(), " but is actually on ",
136 handle_device->name(), " (operation running on ", op_device_name, ")",
137 " Tensors can be copied explicitly using:"
138 " `with tf.device(device_name): x = tf.identity(x)`"
139 " or transparently copied by using"
140 " tf.config.experimental.set_device_policy('silent')."
141 " Copying tensors between devices may slow down your model");
142 case DEVICE_PLACEMENT_WARN:
143 LOG(WARNING) << "before computing " << op->Name() << " input #" << i
144 << " was expected to be on " << expected_input_device->name()
145 << " but is actually on " << handle_device->name()
146 << " (operation running on " << op_device_name
147 << "). This triggers a copy which can be a performance "
148 "bottleneck.";
149 break;
150 case DEVICE_PLACEMENT_SILENT: // Do nothing.
151 break;
152 }
153 // We are only here if the policy is warn or silent copies, so we should
154 // trigger a copy.
155 TensorHandle* result_handle = nullptr;
156 profiler::TraceMe activity(
157 [&] {
158 return absl::StrCat("_Send input ", i, " from ", handle_device->name(),
159 " to ", expected_input_device->name());
160 },
161 profiler::TraceMeLevel::kInfo);
162 Status status =
163 EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device,
164 /* mirror= */ true, &result_handle);
165 activity.Stop();
166 if (!status.ok()) {
167 return Status(
168 status.code(),
169 absl::StrCat("Failed copying input tensor from ", handle_device->name(),
170 " to ", expected_input_device->name(), " in order to run ",
171 op->Name(), ": ", status.error_message()));
172 }
173
174 *result = result_handle;
175
176 return Status::OK();
177 }
178
179 // `op_device_name` is the name of the device on which the op will run, if
180 // any. For functions run using the function library runtime, the device can
181 // be unspecified.
182 Status ValidateInputTypeAndPlacement(
183 EagerContext* ctx, EagerOperation* op,
184 const core::RefCountPtr<KernelAndDevice>& kernel) {
185 profiler::TraceMe activity("ValidateInputTypeAndPlacement",
186 profiler::TraceMeLevel::kInfo);
187 const int n_inputs = op->Inputs().size();
188 if (kernel->num_inputs() != n_inputs) {
189 return errors::InvalidArgument("expected ", kernel->num_inputs(),
190 " inputs, got ", n_inputs);
191 }
192 const bool is_function = kernel->IsFunction();
193 if (n_inputs > 0) {
194 const DataType* input_types = &kernel->input_dtypes()[0];
195 const absl::InlinedVector<TensorHandle*, 4>* handles;
196 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles));
197 for (int i = 0; i < n_inputs; ++i) {
198 TensorHandle* handle = (*handles)[i];
199 Device* expected_device = kernel->InputDevice(i);
200 if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) {
201 // Extract a handle on the op device from a packed input.
202 // This happens when a function is marked for XLA compilation.
203 // MaybePackInputTensor guarantees that a primitive op has no packed
204 // input at this point.
205 for (int j = 0; j < handle->NumPackedHandles(); ++j) {
206 TensorHandle* h = nullptr;
207 TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h));
208 if ((h->op_device() != nullptr) &&
209 (h->op_device()->name() == op->DeviceName())) {
210 op->UpdateInput(i, h);
211 handle = h;
212 break;
213 }
214 }
215 }
216 Device* handle_device = handle->DeviceOrHostCPU(*ctx);
217 const bool maybe_copy =
218 !is_function || handle->Type() != TensorHandle::REMOTE;
219 // If the input is already on the right device, then nothing to do.
220 if (expected_device != handle_device && maybe_copy) {
221 TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
222 handle, i, handle_device,
223 expected_device, &handle));
224 op->UpdateInput(i, handle);
225 // Unref handle since it has a ref as an input now
226 handle->Unref();
227 }
228 if (handle->dtype != input_types[i]) {
229 return errors::InvalidArgument(
230 "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
231 " was expected to be a ", DataTypeString(input_types[i]),
232 " tensor but is a ", DataTypeString(handle->dtype), " tensor");
233 }
234 }
235 }
236 return Status::OK();
237 }
238
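// Resolves the output dtypes for `op` from its NodeDef, using the function
// signature if the op names a function in the eager context's library and the
// registered OpDef otherwise.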
239 Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
240 const auto& node_def = op->MutableAttrs()->BuildNodeDef();
241 const OpDef* op_def = nullptr;
242
243 const FunctionDef* function_def =
244 op->EagerContext().FuncLibDef()->Find(op->Name());
245 if (function_def != nullptr) {
246 op_def = &(function_def->signature());
247 } else {
248 TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
249 }
250
251 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
252
253 return Status::OK();
254 }
255
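// Helpers that fold values into a 128-bit fingerprint; used below to build
// kernel cache keys incrementally.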
256 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
257 const tensorflow::Fprint128& b) {
258 return {tensorflow::FingerprintCat64(a.low64, b.low64),
259 tensorflow::FingerprintCat64(a.high64, b.high64)};
260 }
261
262 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
263 const int64 b) {
264 auto x = tensorflow::FingerprintCat64(a.low64, b);
265 return {x, tensorflow::FingerprintCat64(a.high64, x)};
266 }
267
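// Determines which device a function input should be associated with when
// building the kernel cache key and instantiating the function: non-local
// handles keep their device, DT_RESOURCE inputs use the device recorded in the
// resource handle, and other local tensors map to host CPU when their dtype is
// placed in host memory.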
268 Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
269 Device** result) {
270 Device* cpu_device = ctx.HostCPU();
271 string device_name;
272 if (tensor_handle->Type() != TensorHandle::LOCAL) {
273 Device* device = tensor_handle->device();
274 device_name = device != nullptr ? device->name() : cpu_device->name();
275 *result = (device == nullptr ? cpu_device : device);
276 } else if (tensor_handle->dtype == DT_RESOURCE) {
277 // Use the resource's actual device because it is the device that will
278 // influence partitioning the multi-device function.
279 const Tensor* tensor;
280 // TODO(fishx): Avoid blocking here.
281 TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
282 const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
283 device_name = handle.device();
284
285 Device* input_device;
286 TF_RETURN_IF_ERROR(
287 ctx.FindDeviceFromName(device_name.c_str(), &input_device));
288 *result = input_device;
289 } else {
290 Device* device = tensor_handle->device();
291 const bool is_tpu = device != nullptr && device->device_type() == "TPU";
292 // int32 return values can be placed on TPUs.
293 const bool use_host_memory =
294 is_tpu ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype)
295 : MTypeFromDType(tensor_handle->dtype);
296 if (use_host_memory) {
297 *result = cpu_device;
298 } else {
299 device_name = device != nullptr ? device->name() : cpu_device->name();
300 *result = (device == nullptr ? cpu_device : device);
301 }
302 }
303 return Status::OK();
304 }
305
306 // Appends a TensorShape object to Fprint128 hash.
307 // For best performance, we would like to avoid dynamic memory allocation in
308 // this function.
309 // If "shape" has unknown rank, we attach "?" to hashed content; otherwise we
310 // attach every dim size to hashed content.
311 void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
312 Fprint128* fingerprint) {
313 if (shape.unknown_rank()) {
314 char c = '?';
315 *fingerprint = FingerprintCat128(*fingerprint, c);
316 } else {
317 for (int i = 0; i < shape.dims(); i++) {
318 int64 dim = shape.dim_size(i);
319 *fingerprint = FingerprintCat128(*fingerprint, dim);
320 }
321 }
322 }
323
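// Looks up a boolean attribute, first on the operation itself and, failing
// that, on the FunctionDef registered under the op's name.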
324 Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx,
325 const char* attr_name, bool* value) {
326 Status status = op->Attrs().Get(attr_name, value);
327 if (status.ok()) {
328 DVLOG(2) << "Caller explicitly specifies "
329 << (*value ? "=true " : "=false, ") << op->DebugString();
330 return Status::OK();
331 }
332
333 const FunctionDef* function_def =
334 ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name());
335 if (function_def == nullptr) {
336 return errors::NotFound("Failed to find function '", op->Name(), "'");
337 }
338
339 status = GetNodeAttr(AttrSlice(&function_def->attr()), attr_name, value);
340 if (status.ok()) {
341 DVLOG(2) << "Function definition explicitly specifies "
342 << (*value ? "=true" : "=false");
343 return Status::OK();
344 }
345 return status;
346 }
347
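// Decides whether a function should be compiled with XLA: primitive ops and
// components of multi-device functions are not, an explicit
// kXlaMustCompileAttr wins when present, and otherwise ops placed on TPU or
// XLA_* devices default to XLA compilation.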
348 Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
349 bool* compile_with_xla) {
350 if (!op->is_function()) {
351 *compile_with_xla = false;
352 return Status::OK();
353 }
354
355 if (op->remote_func_params().has_value() &&
356 op->remote_func_params().value().step_id.has_value()) {
357 // If the op is a component of a multi-device function, don't compile it
358 // with XLA.
359 *compile_with_xla = false;
360 return Status::OK();
361 }
362
363 Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla);
364 if (status.ok()) {
365 return Status::OK();
366 }
367
368 // No explicit requests. Compile for XLA devices by default.
369 if (op->GetDeviceParsedName().type == "TPU" ||
370 op->GetDeviceParsedName().type == "XLA_GPU" ||
371 op->GetDeviceParsedName().type == "XLA_CPU") {
372 DVLOG(2) << "Compiling " << op->Name()
373 << " with XLA because it is running on an XLA device "
374 << op->GetDeviceParsedName().type;
375 *compile_with_xla = true;
376 } else {
377 *compile_with_xla = false;
378 }
379
380 return Status::OK();
381 }
382
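// Returns a cached KernelAndDevice for `op`, or builds and caches a new one.
// The cache key mixes the op's attributes and requested device with the
// placement policy, the rendezvous-reuse setting and, for functions, the
// devices and resource dtypes/shapes of the inputs, so a kernel is only reused
// when all of these match.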
383 Status GetOrCreateKernelAndDevice(
384 EagerOperation* op, TensorHandle** retvals, int* num_retvals,
385 core::RefCountPtr<KernelAndDevice>* out_kernel) {
386 EagerContext& ctx = op->EagerContext();
387 Device* device = absl::get<Device*>(op->Device());
388
389 Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
390 // Include soft placement policy in the cache key since the placement strategy
391 // can change and thus affect which kernel is picked.
392 cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
393 // The launch-time rendezvous reuse setting is bundled with the kernel, so we
394 // need to include it in the cache key.
395 cache_key =
396 FingerprintCat128(cache_key, ctx.GetReuseRendezvousForFunctions());
397
398 std::vector<Device*> input_dev_ptrs;
399 absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
400 std::unordered_map<int, DtypeAndPartialTensorShape>
401 input_resource_variable_dtypes_and_shapes;
402 // We can eliminate some overhead by running simple functions using regular
403 // CallOp kernel. However, it is tricky to figure out which functions should
404 // be run using CallOp. Also, currently CallOp runs neither optimization
405 // passes (needed for TPU/XLA) nor grappler.
406 // Here are some cases where a function should be run in multi-device mode:
407 // - Function takes at least two resources on different devices.
408 // - Function takes a resource on deviceA and a body op explicitly placed
409 // on deviceB.
410 // - Function has a colocation constraint.
411 // - Function has an explicit device annotation (which might not be using
412 // full canonical device name) different from op_device. Note that false
413 // positives are ok.
414 // - Function has a node or a (node) attribute that can potentially make
415 // the function multi-device after a rewrite pass (e.g. various XLA/TPU
416 // special nodes and attributes)
417 if (op->is_function()) {
418 profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
419 profiler::TraceMeLevel::kInfo);
420 input_dev_ptrs.reserve(op->Inputs().size());
421 const absl::InlinedVector<TensorHandle*, 4>* inputs;
422 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
423 for (int i = 0, end = inputs->size(); i < end; i++) {
424 TensorHandle* input = (*inputs)[i];
425
426 // Get device for this input, and add it to 'cache_key'.
427 Device* input_device;
428 TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
429 input_dev_ptrs.push_back(input_device);
430 CompositeDevice* composite_device = nullptr;
431 if (ctx.FindCompositeDeviceFromName(input_device->name(),
432 &composite_device)
433 .ok()) {
434 composite_devices[input_device->name()] =
435 composite_device->underlying_devices();
436 }
437 cache_key =
438 FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
439
440 // If input is a ResourceHandle, get its resource handle dtypes and shapes
441 // and add them to 'cache_key'.
442 if (input->dtype == DT_RESOURCE) {
443 // We only care about data type and shape for resource variable inputs.
444 // But we have no way to tell if input is resource variable (other than
445 // looking it up in ResourceMgr, which is slow). So we just get
446 // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
447 // resource_dtypes_and_shapes is not empty, take the first element.
448 std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
449 TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
450 &resource_dtypes_and_shapes));
451 if (!resource_dtypes_and_shapes.empty()) {
452 const DtypeAndPartialTensorShape& dtype_and_shape =
453 resource_dtypes_and_shapes.at(0);
454 input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
455
456 // Add _Arg index, dtype and shape to "cache_key".
457 cache_key = FingerprintCat128(cache_key, i);
458 DataType dtype = dtype_and_shape.dtype;
459 cache_key = FingerprintCat128(cache_key, dtype);
460 AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key);
461 }
462 }
463 }
464 }
465
466 core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
467 if (kernel == nullptr) {
468 DVLOG(2) << "Creating new kernel for " << op->Name() << " on device "
469 << DeviceNameOrUnspecified(absl::get<Device*>(op->Device()));
470 bool run_function_with_flr = false;
471 bool function_outputs_on_op_device = false;
472 if (op->is_function()) {
473 bool compile_with_xla;
474 TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
475 if (compile_with_xla) {
476 // Note that it is not ideal, but currently correct, to set this
477 // attribute after computing the kernel cache key above.
478 // Note: If the attribute is already set to true, this is a noop.
479 op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
480 } else {
481 run_function_with_flr = true;
482 }
483 GetFuncAttr(op, ctx, kOutputsOnOpDevice, &function_outputs_on_op_device)
484 .IgnoreError();
485 }
486
487 const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
488 if (device == nullptr) {
489 TF_RETURN_IF_ERROR(
490 ctx.SelectDevice(op->GetDeviceParsedName(), ndef, &device));
491
492 DVLOG(1) << "Placer placed op [" << op->Name()
493 << "] on device: " << device->name();
494 DVLOG(4) << "Available kernels for " << op->Name() << " are "
495 << KernelsRegisteredForOp(op->Name());
496 op->SetDevice(device);
497 }
498
499 FunctionLibraryRuntime* flr =
500 device == nullptr ? nullptr : ctx.func_lib(device);
501 if (device != nullptr && flr == nullptr) {
502 return errors::Unavailable(
503 "Unable to find a FunctionLibraryRuntime corresponding to device ",
504 device->name());
505 }
506 auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
507 : ctx.runner();
508 GraphCollector* graph_collector = nullptr;
509 if (ctx.ShouldStoreGraphs()) {
510 graph_collector = ctx.GetGraphCollector();
511 }
512 // Treat the function as multi_device only when we are not compiling
513 // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
514 // will create an XlaLaunchOp kernel to compile and run the function.
515 if (run_function_with_flr) {
516 // Multi-device functions don't use the rendezvous from eager context.
517 // If we use that rendezvous, multiple concurrent calls to the same
518 // function will likely result in collisions. However, this also means
519 // that we don't support legitimate sending/receiving across function
520 // boundaries.
521 DVLOG(2) << "Running " << ndef.op() << " using multi-device function. "
522 << "Full node_def=" << ndef.DebugString();
523 std::function<int64()> get_op_id = nullptr;
524 #if !defined(IS_MOBILE_PLATFORM)
525 get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
526 #endif // IS_MOBILE_PLATFORM
527 kernel.reset(new KernelAndDeviceFunc(
528 flr, ctx.pflr(), std::move(input_dev_ptrs),
529 std::move(composite_devices),
530 std::move(input_resource_variable_dtypes_and_shapes), runner,
531 ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
532 function_outputs_on_op_device, ctx.RendezvousCreator(), get_op_id));
533 } else {
534 DVLOG(2) << "Running " << ndef.op() << " using op kernel. "
535 << ". Full node_def=" << ndef.DebugString();
536 kernel.reset(new KernelAndDeviceOp(
537 ctx.GetRendezvous(), ctx.LogMemory(), flr, runner,
538 ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
539 }
540
541 TF_RETURN_IF_ERROR(
542 kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector));
543
544 if (op->is_function()) {
545 ctx.AddKernelToCache(cache_key, kernel.get());
546 } else {
547 // Exclude tf.data op kernels from being cached. The reason for this is
548 // that tf.data op kernels that accept a user-defined function will have a
549 // unique cache key every time they are executed (because the user-defined
550 // function is traced every time). Caching such kernels provides no
551 // benefit and in some cases results in linear memory growth of user
552 // programs that build input pipeline graphs in a loop.
553 const OpDef* op_def;
554 TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def));
555 if (KernelCacheEnabled(*op_def)) {
556 ctx.AddKernelToCache(cache_key, kernel.get());
557 }
558 }
559 }
560
561 int num_outputs = kernel->num_outputs();
562 if (num_outputs > *num_retvals) {
563 return errors::InvalidArgument("Expecting ", num_outputs,
564 " outputs, but *num_retvals is ",
565 *num_retvals);
566 }
567 *num_retvals = num_outputs;
568
569 kernel->Ref(); // Ownership of reference is passed to out_kernel.
570 out_kernel->reset(kernel.get());
571 return Status::OK();
572 }
573
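// Creates a remote TensorHandle for output `output_num` of `kernel` whose
// shape is not yet known. Requires remote function parameters carrying an op
// id and is unavailable on mobile builds.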
574 Status CreateUnshapedOutput(
575 const KernelAndDevice& kernel, const int output_num, Device* output_device,
576 const DataType& output_dtype,
577 const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
578 EagerContext* ctx, TensorHandle** output) {
579 #if defined(IS_MOBILE_PLATFORM)
580 return errors::Unimplemented(
581 "Remote outputs are not available on mobile devices.");
582 #else // !IS_MOBILE_PLATFORM
583 int64 op_id;
584 if (remote_func_params.has_value()) {
585 op_id = remote_func_params.value().op_id;
586 } else {
587 return errors::InvalidArgument(
588 "Unable to find a remote op id for a remote output of ", kernel.name());
589 }
590 string remote_task;
591 if (!DeviceNameUtils::GetTaskName(output_device->parsed_name(),
592 &remote_task)) {
593 return errors::InvalidArgument(
594 "Unable to find remote task corresponding to device ",
595 output_device->name());
596 }
597 if (ctx->RemoteMgr()->IsMaster()) {
598 *output = TensorHandle::CreateUnshapedRemoteHandle(
599 op_id, output_num, remote_task, output_dtype, output_device, ctx);
600 } else {
601 *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num,
602 output_dtype, output_device,
603 /*is_ready=*/false, ctx);
604 }
605 return Status::OK();
606 #endif // !IS_MOBILE_PLATFORM
607 }
608
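// Wraps the prepared kernel into an AsyncExecuteNode (async executors) or an
// ExecuteNode (sync executors) and hands it to the op's executor, allocating
// empty or unshaped output handles up front in the async case.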
609 Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
610 EagerOperation* op, TensorHandle** retvals) {
611 EagerExecutor& executor = op->Executor();
612 EagerContext& ctx = op->EagerContext();
613 GraphCollector* graph_collector = nullptr;
614 if (ctx.ShouldStoreGraphs()) {
615 graph_collector = ctx.GetGraphCollector();
616 }
617 const int num_outputs = kernel->num_outputs();
618 absl::optional<EagerRemoteFunctionParams> remote_func_params =
619 op->remote_func_params();
620 if (kernel->IsCrossProcess() && !remote_func_params.has_value()) {
621 // Create an eager op id for a cross-process function if one does not exist.
622 #if defined(IS_MOBILE_PLATFORM)
623 return errors::Unimplemented(
624 "Cross-process functions are not supported on mobile devices.");
625 #else // !IS_MOBILE_PLATFORM
626 const int64 op_id = ctx.RemoteMgr()->NextOpId();
627 remote_func_params =
628 EagerRemoteFunctionParams{op_id, /*step_id=*/absl::nullopt};
629 #endif // !IS_MOBILE_PLATFORM
630 }
631 if (executor.Async()) {
632 const DataTypeVector& output_dtypes = kernel->output_dtypes();
633 for (int i = 0, end = num_outputs; i < end; ++i) {
634 Device* output_device = ctx.CanonicalDevice(kernel->OutputDevice(i));
635 if (output_device == nullptr || output_device->IsLocal()) {
636 retvals[i] = TensorHandle::CreateEmptyLocalHandle(
637 /* d= */ output_device, /* op_device= */ kernel->device(),
638 /* resource_device= */ kernel->OutputResourceDevice(i),
639 output_dtypes[i], &ctx);
640 } else {
641 TF_RETURN_IF_ERROR(
642 CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
643 remote_func_params, &ctx, &retvals[i]));
644 }
645 }
646 const absl::InlinedVector<TensorHandle*, 4>* inputs;
647 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
648 auto node = absl::make_unique<AsyncExecuteNode>(
649 &ctx, *inputs, remote_func_params, std::move(kernel), graph_collector,
650 op->GetCancellationManager(),
651 absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
652 // Release the inputs from the eager operation since the AsyncExecuteNode
653 // would have taken ownership. This allows the inputs to be forwarded if
654 // possible.
655 op->Clear();
656 // For async mode, execution order will make sure that all
657 // input handles are ready before executing them.
658 // TODO(b/137118203): Consider executing "cheap" kernels inline for
659 // performance.
660 return executor.AddOrExecute(std::move(node));
661 } else {
662 for (int i = 0, end = num_outputs; i < end; ++i) {
663 retvals[i] = nullptr;
664 }
665 const absl::InlinedVector<TensorHandle*, 4>* inputs;
666 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
667 ExecuteNode node(&ctx, *inputs, remote_func_params, kernel, graph_collector,
668 op->GetCancellationManager(),
669 {retvals, static_cast<size_t>(num_outputs)},
670 op->GetStackTrace());
671 Status s = executor.SyncExecute(&node);
672 // We release the inputs AFTER executing the operation in sync mode since
673 // ExecuteNode does not increment the reference count and thus does not have
674 // ownership of the inputs while executing.
675 op->Clear();
676 return s;
677 }
678 }
679
680 // There are a lot of references to devices in this function and around.
681 // Here is what they mean:
682 // EagerOperation::Device(): The device on which the user requested the op
683 // be executed, except if we had to change the device due to resource inputs
684 // or CPU pinning. If the user did not request a device, the op does not
685 // take resources, and we did not pin it to CPU, the device can be nullptr.
686 // KernelAndDevice::Device(): The first time we see an op (combined with
687 // its attributes), we need to create a KernelAndDevice object for it.
688 // If op->Device() is a nullptr, we select a device for the op when
689 // creating the KernelAndDevice. A concrete device will always be selected
690 // here except when `op` is a function to be executed using the function
691 // library runtime. In this case, we don't select a device because running
692 // a function with an explicitly requested device has different behavior than
693 // running without an explicitly requested device.
694 Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
695 int* num_retvals) {
696 ScopedMemoryDebugAnnotation op_annotation(
697 op->op_name(), op->remote_func_params().has_value()
698 ? op->remote_func_params().value().step_id.value_or(0)
699 : 0);
700 profiler::TraceMe activity(
701 [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
702 profiler::TraceMeLevel::kInfo);
703 EagerContext& ctx = op->EagerContext();
704 auto& executor = op->Executor();
705 TF_RETURN_IF_ERROR(executor.status());
706
707 core::RefCountPtr<KernelAndDevice> kernel;
708 auto status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
709
710 // Run all the registered rewrite passes after placement, regardless of
711 // whether the placement is successful or not. The passes can either create
712 // new ops (without placement) or update some fields of the input op.
713 std::unique_ptr<tensorflow::EagerOperation> out_op;
714 TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
715 EagerOpRewriteRegistry::POST_PLACEMENT, op, &out_op));
716 if (out_op) {
717 op = out_op.get();
718 // If the out op doesn't have a device, either because it is a new op or
719 // the op wasn't placed successfully, then we do the placement again.
720 if (op->Device() == kVariantDeviceNull) {
721 status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
722 }
723 }
724 if (!status.ok()) return status;
725
726 int num_outputs = kernel->num_outputs();
727 TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
728
729 if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
730 string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
731 kernel->device()->name());
732 if (!logging::LogToListeners(msg)) {
733 LOG(INFO) << msg;
734 }
735 }
736
737 Status s = AddOrExecuteNode(std::move(kernel), op, retvals);
738 // If the operation failed, we need to Unref any outputs that were
739 // allocated.
740 if (!s.ok()) {
741 for (int i = 0, end = num_outputs; i < end; ++i) {
742 if (retvals[i] != nullptr) {
743 retvals[i]->Unref();
744 }
745 }
746 }
747
748 return s;
749 }
750
751 // Run a Pack op to pack the tensors pointed to by a packed input
752 // TensorHandle if the op is a primitive op.
753 Status MaybePackInputTensor(EagerOperation* op) {
754 if (op->is_function()) {
755 // Functions could take packed TensorHandles as inputs.
756 return Status::OK();
757 }
758 EagerContext& ctx = op->EagerContext();
759 const absl::InlinedVector<TensorHandle*, 4>* inputs;
760 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
761 for (int i = 0; i < inputs->size(); ++i) {
762 TensorHandle* handle = (*inputs)[i];
763 if (handle->Type() == TensorHandle::PACKED) {
764 EagerOperation pack_op(&ctx);
765 TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr,
766 /*remote=*/false, /*executor=*/nullptr));
767 pack_op.MutableAttrs()->Set("N", handle->NumPackedHandles());
768 pack_op.MutableAttrs()->Set("T", handle->dtype);
769 for (int j = 0; j < handle->NumPackedHandles(); ++j) {
770 tensorflow::TensorHandle* h = nullptr;
771 TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h));
772 TF_RETURN_IF_ERROR(pack_op.AddInput(h));
773 }
774 int num_retvals = 1;
775 absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
776 TF_RETURN_IF_ERROR(
777 EagerLocalExecute(&pack_op, retvals.data(), &num_retvals));
778 tensorflow::TensorHandle* ret = retvals.at(0);
779 op->UpdateInput(i, ret);
780 ret->Unref();
781 }
782 }
783 return Status::OK();
784 }
785
786 #if !defined(IS_MOBILE_PLATFORM)
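// Fills in the eager::Operation proto that will be enqueued on the remote
// worker: op id, name, non-default attributes, device and function flag.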
787 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
788 EagerContext& ctx = op->EagerContext();
789
790 remote_op->set_id(ctx.RemoteMgr()->NextOpId());
791 remote_op->set_name(op->Name());
792
793 op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
794 remote_op->set_device(absl::get<Device*>(op->Device())->name());
795 remote_op->set_is_function(op->is_function());
796 }
797
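// For a remote VarHandleOp, records the variable's dtype and shape on the
// returned resource handle so that remote functions taking the variable as an
// input can see them during instantiation.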
798 Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
799 const DataTypeVector& output_dtypes,
800 TensorHandle** retvals) {
801 if (remote_op.name() == "VarHandleOp") {
802 if (output_dtypes.size() != 1) {
803 return errors::Internal("VarHandleOp should only have one output.");
804 }
805 if (output_dtypes[0] != DT_RESOURCE) {
806 return errors::Internal(
807 "The output of VarHandleOp should be a DT_RESOURCE.");
808 }
809 AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
810 const AttrValue* dtype;
811 TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
812 const AttrValue* shape;
813 TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
814 retvals[0]->SetResourceHandleDtypeAndShape(
815 {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
816 }
817 return Status::OK();
818 }
819
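// Executes `op` on a remote worker: serializes the inputs (copying them to the
// target task's CPU device when needed), builds an EnqueueRequest, creates
// unshaped remote handles for the outputs, and queues a RemoteExecuteNode on
// the op's executor.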
820 Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
821 int* num_retvals) {
822 EagerContext& ctx = op->EagerContext();
823
824 // TODO(fishx): Remove following code when lazy tensor copy is ready.
825 if (op->Device() == kVariantDeviceNull) {
826 tensorflow::Device* device = nullptr;
827 string device_name = op->DeviceName();
828 TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
829 op->SetDevice(device);
830 }
831
832 core::RefCountPtr<eager::EagerClient> eager_client;
833 uint64 context_id = ctx.GetContextId();
834 TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client));
835 string remote_task;
836 if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
837 return errors::InvalidArgument(
838 "Unable to find remote task corresponding to device ",
839 op->DeviceName());
840 }
841
842 std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
843 request->set_context_id(context_id);
844
845 eager::Operation* remote_op = request->add_queue()->mutable_operation();
846
847 tensorflow::Device* op_device = absl::get<Device*>(op->Device());
848 {
849 profiler::TraceMe activity("CopyInputToExpectedDevice",
850 profiler::TraceMeLevel::kInfo);
851 const bool is_function = op->is_function();
852 const absl::InlinedVector<TensorHandle*, 4>* inputs;
853 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
854 for (int i = 0, end = inputs->size(); i < end; i++) {
855 tensorflow::TensorHandle* input = (*inputs)[i];
856 tensorflow::Device* input_device = input->device();
857 tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx);
858 const string* input_device_name = &input_device_or_cpu->name();
859 bool serialize_resource_dtype_and_shape = false;
860 if (op_device != input_device &&
861 // If the expected and actual devices are on the same task, don't
862 // explicitly copy, and instead depend on the copy to happen locally
863 // when the op is executed on the device.
864 !ctx.OnSameTask(op_device, input_device)) {
865 if (!is_function || input_device_or_cpu->IsLocal()) {
866 tensorflow::Device* remote_cpu_device;
867 TF_RETURN_IF_ERROR(
868 ctx.CPUDeviceOnTask(op_device, &remote_cpu_device));
869 // TODO(b/110044833): It's possible the same tensor gets copied to the
870 // remote device repeatedly.
871 // Always copy to the remote CPU so that the actual device can be
872 // correctly determined after the kernel is selected/instantiated,
873 // since the op might have its inputs on host memory.
874 TensorHandle* handle = input;
875 Device* handle_device = handle->DeviceOrHostCPU(ctx);
876 // If the input is already on the right device, then nothing to do.
877 if (remote_cpu_device != handle_device) {
878 TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
879 &ctx, op, op_device, handle, i, handle_device,
880 remote_cpu_device, &handle));
881 op->UpdateInput(i, handle);
882 input = handle;
883 input_device = remote_cpu_device;
884 input_device_name = &remote_cpu_device->name();
885 // Unref handle since it has a ref as an input now
886 handle->Unref();
887 }
888 } else {
889 serialize_resource_dtype_and_shape =
890 (input->dtype == DT_RESOURCE) &&
891 (!input->HasResourceShapeMirror(op_device,
892 ctx.GetContextViewId()));
893 }
894 }
895 auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
896 // For a remote component function, a function execution request and an
897 // input generation request may come from different workers. We need to
898 // guarantee that the input generation request is processed before the
899 // function execution request, so wait until the remote input is ready
900 // before sending it to the multi-device function device.
901 const bool wait_until_ready = op->is_function();
902 TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
903 input, wait_until_ready, input_handle, input_device,
904 *input_device_name, serialize_resource_dtype_and_shape));
905 if (!input_handle->resource_dtypes_and_shapes().empty()) {
906 TF_RETURN_IF_ERROR(
907 input->AddResourceShapeMirror(op_device, input_handle->op_id(),
908 input_handle->output_num(), &ctx));
909 }
910 }
911 }
912
913 PrepareRemoteOp(remote_op, op);
914
915 DataTypeVector output_dtypes;
916 TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
917
918 const size_t num_outputs = output_dtypes.size();
919 if (num_outputs != *num_retvals) {
920 return errors::InvalidArgument(
921 "num_retvals does not match expected output dtypes");
922 }
923 *num_retvals = num_outputs;
924
925 const tensorflow::uint64 id = remote_op->id();
926 for (size_t i = 0; i < num_outputs; ++i) {
927 // TODO(nareshmodi): Change the callback to instead add the decref to a
928 // list of pending decrefs that we can send as a batch with the next
929 // execute.
930
931 // The device_ and resource_device_ of this TensorHandle might be
932 // incorrect. For multi-device functions, we don't know the output device
933 // until the function is instantiated on a remote worker. Luckily, we don't
934 // need to know the correct remote device here. We just need to know that it
935 // is remote. If we need to copy this tensor to this process or run any ops
936 // which take this tensor as an input, we block until the correct device is
937 // set.
938 const bool unknown_device = op->is_function();
939 retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
940 id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
941 }
942
943 // Store the data type and shape of a remote resource variable on the
944 // corresponding remote TensorHandle (output of 'VarHandleOp').
945 // If the variable is an input of a remote function, the function may need
946 // the type and shape during function instantiation. Store the type and
947 // shape on the eager master and send them to the default function device
948 // along with the EnqueueRequest.
949 TF_RETURN_IF_ERROR(
950 StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
951
952 auto& executor = op->Executor();
953 DVLOG(4) << "Execute remote eager op: " << op->Name()
954 << " (is async?: " << executor.Async() << ").";
955
956 const absl::InlinedVector<TensorHandle*, 4>* inputs;
957 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
958
959 std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
960 &op->EagerContext(), std::move(request), op_device,
961 ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(),
962 op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
963 *inputs, {retvals, num_outputs}));
964
965 if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
966 string msg = strings::StrCat(
967 "Executing op ", op->Name(), " on task ",
968 DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
969 if (!logging::LogToListeners(msg)) {
970 LOG(INFO) << msg;
971 }
972 }
973
974 Status s = executor.AddOrExecute(std::move(node));
975 // If the operation failed, we need to Unref any outputs that were
976 // allocated.
977 if (!s.ok()) {
978 for (size_t i = 0; i < num_outputs; ++i) {
979 retvals[i]->Unref();
980 }
981 }
982
983 return s;
984 }
985 #endif // IS_MOBILE_PLATFORM
986
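// Converts the kernel's EagerKernelRet outputs (tensors or remote shapes) into
// the caller-visible TensorHandles, either creating new handles or filling in
// handles that were pre-allocated in async mode.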
987 Status GetKernelOutputs(
988 std::vector<EagerKernelRet>* outputs, int num_outputs,
989 TensorHandle** retvals, EagerContext* ctx, KernelAndDevice* kernel,
990 const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
991 for (int i = 0, end = num_outputs; i < end; ++i) {
992 if (retvals[i] == nullptr) {
993 EagerKernelRet& ret = (*outputs)[i];
994 Device* output_device = ctx->CanonicalDevice(kernel->OutputDevice(i));
995 if (ret.index() == 0) {
996 retvals[i] = TensorHandle::CreateLocalHandle(
997 std::move(absl::get<Tensor>(ret)),
998 /* d= */ output_device,
999 /* op_device= */ kernel->device(),
1000 /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
1001 } else {
1002 const DataTypeVector& output_dtypes = kernel->output_dtypes();
1003 TF_RETURN_IF_ERROR(
1004 CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1005 remote_func_params, ctx, &retvals[i]));
1006 #if !defined(IS_MOBILE_PLATFORM)
1007 TF_RETURN_IF_ERROR(
1008 retvals[i]->SetRemoteShape(absl::get<TensorShape>(ret),
1009 output_device, ctx->GetContextViewId()));
1010 #endif // IS_MOBILE_PLATFORM
1011 }
1012 } else {
1013 if (!kernel->IsFunction() &&
1014 TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) {
1015 return errors::Internal(
1016 "Kernel output tensor handle has a different op device than the "
1017 "kernel. This should never happen.");
1018 }
1019 if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
1020 retvals[i]->device())) {
1021 return errors::Internal(
1022 "Kernel output tensor handle locates on a different device than "
1023 "the specified kernel output device. This should never happen.");
1024 }
1025
1026 EagerKernelRet& ret = (*outputs)[i];
1027 if (ret.index() == 0) {
1028 TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
1029 std::move(absl::get<Tensor>(ret)),
1030 ctx->CanonicalDevice(kernel->OutputDevice(i))));
1031 } else {
1032 #if defined(IS_MOBILE_PLATFORM)
1033 return errors::Unimplemented(
1034 "Remote outputs are not available on mobile devices.");
1035 #else // !IS_MOBILE_PLATFORM
1036 TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape(
1037 absl::get<TensorShape>(ret), retvals[i]->device(),
1038 ctx->GetContextViewId()));
1039 #endif // !IS_MOBILE_PLATFORM
1040 }
1041 }
1042 }
1043 return Status::OK();
1044 }
1045
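// Copies any graphs recorded by the context's GraphCollector into the
// RunMetadata proto and then clears the collector.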
1046 void CollectGraphs(EagerContext* ctx) {
1047 mutex_lock ml(*ctx->MetadataMu());
1048
1049 GraphCollector* collector = ctx->GetGraphCollector();
1050 mutex_lock mll(collector->mu);
1051
1052 // Adding to partition graphs for backward compatibility.
1053 for (const auto& graph : collector->partitioned_graphs) {
1054 *ctx->RunMetadataProto()->add_partition_graphs() = graph;
1055 }
1056
1057 if (collector->dirty) {
1058 auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
1059 *function_graphs->mutable_post_optimization_graph() =
1060 collector->optimized_graph;
1061 *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph;
1062 for (const auto& graph : collector->partitioned_graphs) {
1063 *function_graphs->add_partition_graphs() = graph;
1064 }
1065 }
1066
1067 collector->ClearGraphs();
1068 }
1069 } // namespace
1070
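// Entry point for eager op execution: runs pre-execution rewrites and then
// dispatches to local or remote execution depending on the op's device.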
1071 Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
1072 int* num_retvals) {
1073 profiler::TraceMe activity(
1074 [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
1075 profiler::TraceMeLevel::kInfo);
1076
1077 if (!op->Executor().Async()) {
1078 // In sync mode, always clear error to maintain the same behavior as before.
1079 // TODO(b/141004939): Remove this.
1080 op->Executor().ClearError();
1081 }
1082
1083 std::unique_ptr<tensorflow::EagerOperation> out_op;
1084 TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1085 EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
1086
1087 if (op->IsLocal()) {
1088 if (out_op) {
1089 op = out_op.get();
1090 }
1091 TF_RETURN_IF_ERROR(MaybePackInputTensor(op));
1092 return EagerLocalExecute(op, retvals, num_retvals);
1093 }
1094
1095 #if defined(IS_MOBILE_PLATFORM)
1096 return errors::Unimplemented(
1097 "Eager's remote execution is not available on mobile devices.");
1098 #else // !IS_MOBILE_PLATFORM
1099 if (out_op) {
1100 op = out_op.get();
1101 }
1102 return EagerRemoteExecute(op, retvals, num_retvals);
1103 #endif // !IS_MOBILE_PLATFORM
1104 }
1105
1106 // TODO(gjn): Consider moving into ExecuteNode class
1107 Status EagerKernelExecute(
1108 EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1109 const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1110 const core::RefCountPtr<KernelAndDevice>& kernel,
1111 GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1112 absl::Span<TensorHandle*> retvals,
1113 const absl::optional<ManagedStackTrace>& stack_trace) {
1114 profiler::TraceMe activity("EagerKernelExecute",
1115 profiler::TraceMeLevel::kInfo);
1116 std::vector<EagerKernelRet> outputs(1);
1117
1118 ExecuteNodeArgs inputs(op_inputs.size());
1119 TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs, kernel));
1120 // TODO(apassos) figure out how to record stats for ops which are a part of
1121 // functions.
1122 // TODO(b/111859745): When we support recovering from kernel/device errors, we
1123 // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA
1124 // device. We don't call it now because it is an unneeded overhead (it
1125 // acquires a lock) and we can't recover from errors anyway.
1126 ScopedStepContainer* container = ctx->StepContainer();
1127 TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
1128 cancellation_manager, remote_func_params,
1129 stack_trace));
1130 if (graph_collector != nullptr) {
1131 CollectGraphs(ctx);
1132 }
1133
1134 if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) {
1135 return errors::Internal(
1136 "EagerKernelExecute returns a list of ", outputs.size(),
1137 " tensors but ", retvals.size(),
1138 " is expected. This should never "
1139 "happen. Please file a bug with the TensorFlow team.");
1140 }
1141 return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx,
1142 kernel.get(), remote_func_params);
1143 }
1144
1145 namespace {
1146
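// Copies `h` to the local device `dstd`, or adds a local mirror when `mirror`
// is set, using a CopyToDeviceNode that is either queued on the executor or
// run synchronously.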
1147 Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1148 EagerExecutor* executor, Device* dstd,
1149 bool mirror, TensorHandle** result) {
1150 TF_RETURN_IF_ERROR(executor->status());
1151 Device* d = ctx->CanonicalDevice(dstd);
1152 if (mirror && h->HasLocalMirror(d)) {
1153 h->Ref();
1154 *result = h;
1155 return Status::OK();
1156 }
1157
1158 bool async = executor->Async();
1159 if (mirror) {
1160 h->Ref();
1161 *result = h;
1162
1163 if (h->HasLocalMirror(d)) {
1164 return Status::OK();
1165 }
1166
1167 // We don't bother adding an empty local mirror in sync mode since we'll be
1168 // executing the operation directly and be calling AddLocalMirror. A
1169 // reference count is still needed which will be removed if the operation
1170 // fails.
1171 if (async) {
1172 Status s = h->AddEmptyLocalMirror(d);
1173 if (!s.ok()) {
1174 // If a mirror was added since we called HasLocalMirror then just return
1175 // since another thread has already added the mirror.
1176 if (s.code() == error::Code::ALREADY_EXISTS) {
1177 return Status::OK();
1178 }
1179
1180 // Remove the previously added reference count since adding the mirror
1181 // failed.
1182 h->Unref();
1183 *result = nullptr;
1184 return s;
1185 }
1186 }
1187 } else {
1188 *result = TensorHandle::CreateEmptyLocalHandle(
1189 d, dstd, h->resource_device(), h->dtype, ctx);
1190 }
1191
1192 Status s;
1193 if (async) {
1194 // Note that `h` may not be currently ready. However execution order will
1195 // make sure that `h` is ready before the copy is actually done.
1196 std::unique_ptr<EagerNode> node(
1197 new CopyToDeviceNode(h, *result, d, *ctx, async, mirror));
1198 s = executor->AddOrExecute(std::move(node));
1199 } else {
1200 CopyToDeviceNode node(h, *result, d, *ctx, async, mirror);
1201 s = executor->SyncExecute(&node);
1202 }
1203
1204 // If the operation failed, we need to Unref any outputs that were
1205 // allocated.
1206 if (!s.ok()) {
1207 (*result)->Unref();
1208 }
1209
1210 return s;
1211 }
1212
1213 } // namespace
1214
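// Copies (or mirrors) `h` onto `device`, choosing between a purely local copy
// and a remote copy/receive depending on whether the source and destination
// devices are local to this process.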
1215 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1216 EagerExecutor* executor, Device* device, bool mirror,
1217 TensorHandle** result) {
1218 TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
1219 auto send_device = h->DeviceOrHostCPU(*ctx);
1220 bool sender_is_local = send_device->IsLocal();
1221
1222 bool receiver_is_local = device->IsLocal();
1223
1224 if (!executor->Async()) {
1225 // In sync mode, always clear error to maintain the same behavior as before.
1226 // TODO(b/141004939): Remove this.
1227 executor->ClearError();
1228 }
1229
1230 if (sender_is_local && receiver_is_local) {
1231 return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result);
1232 } else {
1233 #if defined(IS_MOBILE_PLATFORM)
1234 return errors::Unimplemented(
1235 "Eager's remote execution is not available on mobile devices.");
1236 #else // !IS_MOBILE_PLATFORM
1237 uint64 recv_op_id = 0;
1238 if (receiver_is_local) {
1239 Device* d = ctx->CanonicalDevice(device);
1240 // TODO(gjn): Need to add support for async execution. Note if receiver
1241 // is local, we need to first add support in TensorHandle to wait on local
1242 // mirrors.
1243 if (mirror) {
1244 h->Ref();
1245 *result = h;
1246
1247 if (h->HasLocalMirror(d)) {
1248 return Status::OK();
1249 }
1250
1251 Status s = h->AddEmptyLocalMirror(d);
1252 if (!s.ok()) {
1253 // If a mirror was added since we called HasLocalMirror then just
1254 // return since another thread has already added the mirror.
1255 if (s.code() == error::Code::ALREADY_EXISTS) {
1256 return Status::OK();
1257 }
1258
1259 // Remove the previously added reference count since adding the mirror
1260 // failed.
1261 h->Unref();
1262 *result = nullptr;
1263 return s;
1264 }
1265 } else {
1266 *result = TensorHandle::CreateEmptyLocalHandle(
1267 /* d= */ d, /* op_device= */ device,
1268 /*resource_device=*/nullptr, h->dtype, ctx);
1269 }
1270 } else {
1271 if (mirror) {
1272 if (h->HasRemoteMirror(device, ctx->GetContextViewId())) {
1273 h->Ref();
1274 *result = h;
1275 return Status::OK();
1276 }
1277 }
1278 string remote_task;
1279 if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
1280 return errors::InvalidArgument(
1281 "Unable to find remote task corresponding to device ",
1282 device->name());
1283 }
1284 recv_op_id = ctx->RemoteMgr()->NextOpId();
1285 if (mirror) {
1286 TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
1287 remote_task, ctx));
1288 h->Ref();
1289 *result = h;
1290 } else {
1291 *result = TensorHandle::CreateUnshapedRemoteHandle(
1292 recv_op_id, 0, remote_task, h->dtype, device, ctx);
1293 }
1294 }
1295
1296 auto node = std::make_unique<eager::RemoteCopyNode>(
1297 ctx, executor, h, result[0], device, recv_op_id);
1298 Status s = executor->AddOrExecute(std::move(node));
1299 if (!s.ok()) {
1300 result[0]->Unref();
1301 }
1302 return s;
1303 #endif // !IS_MOBILE_PLATFORM
1304 }
1305 }
1306
1307 namespace {
1308 // Low-level utility function to execute the kernel specified by `kernel` on
1309 // `kernel->device()`, with the provided inputs as `op_inputs` in `ctx`.
1310 // Unlike `EagerKernelExecute`, which ties up the thread until the underlying
1311 // function finishes executing, this function does not block the thread
1312 // and could return before the function execution finishes. The provided
1313 // `StatusCallback` will be triggered after function execution with its status.
1314 void EagerKernelExecuteAsync(
1315 EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1316 const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
1317 const core::RefCountPtr<KernelAndDevice> kernel,
1318 GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1319 TensorHandle** retvals, int num_outputs, StatusCallback done) {
1320 auto inputs = std::make_shared<ExecuteNodeArgs>(op_inputs.size());
1321 auto outputs = std::make_shared<std::vector<EagerKernelRet>>(1);
1322
1323 Status s = inputs->Init(ctx, op_inputs, kernel);
1324 if (!s.ok()) {
1325 done(s);
1326 return;
1327 }
1328
1329 kernel->Ref(); // Ownership of reference is transferred to the callback
1330 kernel->RunAsync(
1331 ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager,
1332 remote_func_params,
1333 [retvals, inputs, outputs, num_outputs, ctx, graph_collector,
1334 remote_func_params, kernel_raw = kernel.get(),
1335 done = std::move(done)](const Status& s) {
1336 auto wrapped_done = [&](const Status& s) {
1337 kernel_raw->Unref();
1338 done(s);
1339 };
1340 if (!s.ok()) {
1341 wrapped_done(s);
1342 return;
1343 }
1344 if (graph_collector != nullptr) {
1345 CollectGraphs(ctx);
1346 }
1347 DCHECK_EQ(num_outputs, outputs->size());
1348 wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx,
1349 kernel_raw, remote_func_params));
1350 });
1351 }
1352 } // namespace
1353
1354 // Low-level utility to run the eager operation on local devices. Unlike
1355 // `EagerLocalExecute`, which blocks and waits for the op execution to
1356 // finish, this method does not block the thread and could return before the
1357 // eager operation execution finishes. The provided `StatusCallback` will be
1358 // triggered after execution with its status.
1359 void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
1360 int* num_retvals, StatusCallback done) {
1361 if (!op->IsLocal()) {
1362 done(errors::InvalidArgument(
1363 "Remote execution is not supported in async EagerLocalExecuteAsync"));
1364 return;
1365 }
1366
1367 ScopedMemoryDebugAnnotation op_annotation(
1368 op->op_name(), op->remote_func_params().has_value()
1369 ? op->remote_func_params().value().step_id.value_or(0)
1370 : 0);
1371 profiler::TraceMe activity(
1372 [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); },
1373 profiler::TraceMeLevel::kInfo);
1374 EagerContext& ctx = op->EagerContext();
1375
1376 core::RefCountPtr<KernelAndDevice> kernel;
1377 Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1378 if (!s.ok()) {
1379 done(s);
1380 return;
1381 }
1382
1383 int num_outputs = kernel->num_outputs();
1384 s = ValidateInputTypeAndPlacement(&ctx, op, kernel);
1385 if (!s.ok()) {
1386 done(s);
1387 return;
1388 }
1389
1390 if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
1391 string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
1392 kernel->device()->name());
1393 if (!logging::LogToListeners(msg)) {
1394 LOG(INFO) << msg;
1395 }
1396 }
1397
1398 GraphCollector* graph_collector = nullptr;
1399 if (ctx.ShouldStoreGraphs()) {
1400 graph_collector = ctx.GetGraphCollector();
1401 }
1402
1403 for (int i = 0, end = num_outputs; i < end; ++i) {
1404 const DataTypeVector& output_dtypes = kernel->output_dtypes();
1405 retvals[i] = TensorHandle::CreateEmptyLocalHandle(
1406 /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
1407 /* op_device= */ kernel->device(),
1408 /* resource_device= */ kernel->OutputResourceDevice(i),
1409 output_dtypes[i], &ctx);
1410 }
1411
1412 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1413 s = op->TensorHandleInputs(&inputs);
1414 if (!s.ok()) {
1415 done(s);
1416 return;
1417 }
1418 EagerKernelExecuteAsync(
1419 &ctx, *inputs, op->remote_func_params(), std::move(kernel),
1420 graph_collector, op->GetCancellationManager(), retvals, num_outputs,
1421 [op, num_outputs, retvals, done = std::move(done)](const Status& s) {
1422 op->Clear();
1423 // If the operation failed, we need to Unref any outputs that were
1424 // allocated.
1425 if (!s.ok()) {
1426 for (int i = 0, end = num_outputs; i < end; ++i) {
1427 if (retvals[i] != nullptr) {
1428 retvals[i]->Unref();
1429 }
1430 }
1431 }
1432 done(s);
1433 });
1434 }
1435 } // namespace tensorflow
1436