/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/execute.h"

#include <vector>

#include "absl/strings/match.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
#endif
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

namespace {

// Copy of the definition in third_party/tensorflow/compiler/jit/defs.h.
// Copied here because we don't currently compile XLA on Windows, so we can't
// depend on it directly.
const char* const kXlaCompileAttr = "_XlaCompile";

// Initializes the step stats if needed.
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
  // Lazily initialize the RunMetadata with information about all devices if
  // this is the first call.
  while (step_stats->dev_stats_size() < ctx->devices()->size()) {
    int device_idx = step_stats->dev_stats_size();
    auto* dev_stats = step_stats->add_dev_stats();
    dev_stats->set_device(ctx->devices()->at(device_idx)->name());
  }
}

int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
                         Device* device) {
  // Find the current device's index.
  if (device == nullptr) {
    device = ctx->HostCPU();
  }
  for (int i = 0; i < ctx->devices()->size(); ++i) {
    if (ctx->devices()->at(i) == device ||
        ctx->devices()->at(i)->name() == device->name()) {
      return i;
    }
  }
  // TODO(apassos) do not fall back to host CPU if device is unknown.
  return 0;
}

// This function expects *handle to point to an existing tensor handle. If a
// copy is required, it updates *handle to point to the newly copied tensor
// handle.
//
// The passed-in *handle is Unreffed if it is replaced.
//
// `op_device_name` is passed in explicitly because `op->device()` might be
// unset and we might have selected some specific device to run this op on.
Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
                                      const string& op_device_name, int i,
                                      const Device* expected_input_device,
                                      RunMetadata* run_metadata,
                                      TensorHandle** handle) {
  EagerContext* ctx = op->EagerContext();
  Device* handle_device = (*handle)->device();
  const Device* actual_device =
      handle_device == nullptr ? ctx->HostCPU() : handle_device;

  if (expected_input_device != actual_device) {
    switch (ctx->GetDevicePlacementPolicy()) {
      case DEVICE_PLACEMENT_SILENT_FOR_INT32:
        // TODO(xpan): See if we could bubble python related error up
        // to python level.
        if ((*handle)->dtype == DT_INT32) {
          // Note: enabling silent copies of int32 tensors to match behavior
          // of graph mode.
          break;
        }
        TF_FALLTHROUGH_INTENDED;
      case DEVICE_PLACEMENT_EXPLICIT:
        return errors::InvalidArgument(
            "Tensors on conflicting devices:"
            " cannot compute ",
            op->Name(), " as input #", i, " was expected to be on ",
            expected_input_device->name(), " but is actually on ",
            actual_device->name(), " (operation running on ", op_device_name,
            ")",
            " Tensors can be copied explicitly using .gpu() or .cpu() "
            "methods,"
            " or transparently copied by using tf.enable_eager_execution("
            "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
            "between devices"
            " may slow down your model");
      case DEVICE_PLACEMENT_WARN:
        LOG(WARNING) << "before computing " << op->Name() << " input #" << i
                     << " was expected to be on "
                     << expected_input_device->name() << " but is actually on "
                     << actual_device->name() << " (operation running on "
                     << op_device_name
                     << "). This triggers a copy which can be a performance "
                        "bottleneck.";
        break;
      case DEVICE_PLACEMENT_SILENT:  // Do nothing.
        break;
    }
    // We are only here if the policy is warn or silent copies, so we should
    // trigger a copy.
    auto pre_time_nanos = Env::Default()->NowNanos();
    TensorHandle* result_handle = nullptr;
    Status status = EagerCopyToDevice(
        *handle, ctx, expected_input_device->name().c_str(), &result_handle);
    if (run_metadata != nullptr) {
      auto* step_stats = run_metadata->mutable_step_stats();
      MaybeInitializeStepStats(step_stats, ctx);
      // Record the sending on the source device for now.
      int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
      auto* node_stats = dev_stats->add_node_stats();
      node_stats->set_node_name("_Send");
      node_stats->set_all_start_micros(pre_time_nanos /
                                       EnvTime::kMicrosToNanos);
      node_stats->set_all_start_nanos(pre_time_nanos);
      int64 now_nanos = Env::Default()->NowNanos();
      node_stats->set_op_end_rel_micros((now_nanos - pre_time_nanos) /
                                        EnvTime::kMicrosToNanos);
      node_stats->set_op_end_rel_nanos(now_nanos - pre_time_nanos);
      node_stats->set_all_end_rel_micros((now_nanos - pre_time_nanos) /
                                         EnvTime::kMicrosToNanos);
      node_stats->set_all_end_rel_nanos(now_nanos - pre_time_nanos);
    }
    if (!status.ok()) {
      if (result_handle != nullptr) result_handle->Unref();
      return errors::Internal(
          "Failed copying input tensor from ", actual_device->name(), " to ",
          expected_input_device->name(), " in order to run ", op->Name(), ": ",
          status.error_message());
    }

    (*handle)->Unref();
    *handle = result_handle;
  }
  return Status::OK();
}

// `op_device_name` is the name of the device on which the op will run, if any.
// For functions run via the function library runtime, the device can be
// unspecified.
Status ValidateInputTypeAndPlacement(EagerContext* ctx,
                                     const string& op_device_name,
                                     EagerOperation* op,
                                     const KernelAndDevice* kernel,
                                     RunMetadata* run_metadata) {
  if (kernel->num_inputs() != op->Inputs().size()) {
    return errors::InvalidArgument("expected ", kernel->num_inputs(),
                                   " inputs, got ", op->Inputs().size());
  }
  for (int i = 0; i < op->Inputs().size(); ++i) {
    const Device* expected_device = kernel->InputDevice(i);
    TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
        op, op_device_name, i, expected_device, run_metadata,
        &((*op->MutableInputs())[i])));
    tensorflow::TensorHandle* handle = op->Inputs()[i];
    if (handle->dtype != kernel->input_type(i)) {
      return errors::InvalidArgument(
          "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
          " was expected to be a ", DataTypeString(kernel->input_type(i)),
          " tensor but is a ", DataTypeString(handle->dtype), " tensor");
    }
  }
  return Status::OK();
}

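// Selects a device for `ndef` when the user did not specify one: takes the
// highest-priority device type that has a registered kernel for this node and
// returns the first device of that type in the context's device list.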
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
  PrioritizedDeviceTypeVector final_devices;
  TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
      ctx->prioritized_device_type_list(), ndef, &final_devices));
  if (final_devices.empty()) {
    return errors::Internal("Could not find valid device for node.\nNode: ",
                            FormatNodeDefForError(ndef),
                            "\nAll kernels registered for op ", ndef.op(),
                            " :\n", KernelsRegisteredForOp(ndef.op()));
  }
  for (Device* d : *ctx->devices()) {
    if (d->device_type() == final_devices[0].first.type_string()) {
      *device = d;
      return Status::OK();
    }
  }
  return errors::Unknown("Could not find a device for node ",
                         FormatNodeDefForError(ndef));
}

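// Computes the output dtypes of `op` from its NodeDef and the matching OpDef
// (or, if `op` is a function call, from the function's signature).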
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
  const auto& node_def = op->MutableAttrs()->BuildNodeDef();
  const OpDef* op_def = nullptr;

  const FunctionDef* function_def =
      op->EagerContext()->FuncLibDef()->Find(op->Name());
  if (function_def != nullptr) {
    op_def = &(function_def->signature());
  } else {
    TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
  }

  TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));

  return Status::OK();
}

}  // namespace

namespace {
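// Returns true if `d` is considered local to this process: it is null, there
// is no remote device manager, or the local device manager can resolve it.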
bool IsLocal(EagerContext* ctx, tensorflow::Device* d) {
  if (d == nullptr || ctx->remote_device_mgr() == nullptr) return true;
  tensorflow::Device* tmp;
  return ctx->local_device_mgr()->LookupDevice(d->name(), &tmp).ok();
}

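// Returns true if `first` and `second` (either may be null, meaning the host
// CPU) belong to the same job, replica, and task.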
bool OnSameTask(EagerContext* ctx, Device* first, Device* second) {
  if (first == nullptr) first = ctx->HostCPU();
  if (second == nullptr) second = ctx->HostCPU();
  return first->parsed_name().job == second->parsed_name().job &&
         first->parsed_name().replica == second->parsed_name().replica &&
         first->parsed_name().task == second->parsed_name().task;
}

// Gets the CPU device on the same task as `device`.
Status CPUDeviceOnTask(EagerContext* ctx, tensorflow::Device* device,
                       tensorflow::Device** cpu_device) {
  string cpu_device_name;
  TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
      device->name(), &cpu_device_name));

  return ctx->FindDeviceByName(cpu_device_name, cpu_device);
}

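// Concatenates two 128-bit fingerprints by fingerprint-concatenating their
// low and high 64-bit halves separately.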
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
  return {tensorflow::FingerprintCat64(a.low64, b.low64),
          tensorflow::FingerprintCat64(a.high64, b.high64)};
}

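// Resolves `device_name` to a device, defaulting to the host CPU when the
// name is null or empty, and falling back to the remote device manager when
// the local lookup fails.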
Status FindDeviceFromName(const EagerContext* ctx, const char* device_name,
                          Device** device) {
  *device = ctx->HostCPU();
  if (device_name == nullptr || strlen(device_name) == 0) {
    return Status::OK();
  }

  auto status = ctx->local_device_mgr()->LookupDevice(device_name, device);
  if (status.ok()) {
    return status;
  }

  if (ctx->remote_device_mgr() != nullptr) {
    return ctx->remote_device_mgr()->LookupDevice(device_name, device);
  }

  return status;
}

bool IsMultiDevice(const FunctionDef* fdef, const string& op_device) {
  if (fdef == nullptr) {
    // Primitive op.
    return false;
  }

  // Run all functions as multi-device.
  return true;

  // We can eliminate some overhead by running simple functions using regular
  // CallOp kernel. However, it is tricky to figure out which functions should
  // be run using CallOp. Also, currently CallOp runs neither optimization
  // passes (needed for TPU/XLA) nor grappler.
  // Here are some cases where a function should be run in multi-device mode:
  //  - Function takes at least two resources on different devices.
  //  - Function takes a resource on deviceA and a body op explicitly placed
  //    on deviceB.
  //  - Function has a colocation constraint.
  //  - Function has an explicit device annotation (which might not be using
  //    full canonical device name) different from op_device. Note that false
  //    positives are ok.
  //  - Function has a node or a (node) attribute that can potentially make
  //    the function multi-device after a rewrite pass (e.g. various XLA/TPU
  //    special nodes and attributes)
}

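// Records the device of each of `op`'s inputs in `input_dev_ptrs` (for
// DT_RESOURCE inputs, the device named in the resource handle; for
// host-memory dtypes, the host CPU) and folds every device name into
// `cache_key` so that the kernel cache key reflects input placement.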
Status AddInputDevicesToCacheKey(const EagerContext* ctx,
                                 const EagerOperation* op,
                                 std::vector<Device*>* input_dev_ptrs,
                                 Fprint128* cache_key) {
  input_dev_ptrs->reserve(op->Inputs().size());
  Device* cpu_device = ctx->HostCPU();
  for (TensorHandle* tensor_handle : op->Inputs()) {
    string device_name;
    if (tensor_handle->dtype == DT_RESOURCE) {
      // Use the resource's actual device because it is the device that will
      // influence partitioning the multi-device function.
      const Tensor* tensor;
      TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
      const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
      device_name = handle.device();

      Device* input_device;
      TF_RETURN_IF_ERROR(
          FindDeviceFromName(ctx, device_name.c_str(), &input_device));
      input_dev_ptrs->push_back(input_device);
    } else if (MTypeFromDType(tensor_handle->dtype) == HOST_MEMORY) {
      input_dev_ptrs->push_back(cpu_device);
    } else {
      Device* device = tensor_handle->device();
      device_name = device != nullptr ? device->name() : cpu_device->name();
      input_dev_ptrs->push_back(device == nullptr ? cpu_device : device);
    }
    *cache_key = FingerprintCat128(*cache_key, Fingerprint128(device_name));
  }
  return Status::OK();
}

// There are a lot of references to devices in this function and around.
// Here is what they mean:
//  EagerOperation::Device(): The device on which the user requested the op
//    be executed, except if we had to change the device due to resource inputs
//    or CPU pinning. If the user did not request a device, the op does not
//    take resources, and we did not pin it to CPU, the device can be nullptr.
//  KernelAndDevice::Device(): The first time we see an op (combined with
//    its attributes), we need to create a KernelAndDevice object for it.
//    If op->Device() is a nullptr, we select a device for the op when
//    creating the KernelAndDevice. A concrete device will always be selected
//    here except when `op` is a function to be executed using function library
//    runtime. In this case, we don't select a device because running
//    a function with an explicitly requested device has different behavior
//    than running without an explicitly requested device.
Status EagerLocalExecute(EagerOperation* op,
                         gtl::InlinedVector<TensorHandle*, 2>* retvals,
                         int* num_retvals) {
  const string unspecified_device_name("<unspecified>");
  EagerContext* ctx = op->EagerContext();
  auto status = ctx->GetStatus();
  if (!status.ok()) return status;
  Device* device = op->Device();

  const string& maybe_unspecified_device_name =
      device == nullptr ? unspecified_device_name : device->name();
  Fprint128 cache_key =
      op->MutableAttrs()->CacheKey(maybe_unspecified_device_name);

  bool is_multi_device_function = IsMultiDevice(
      ctx->FindFunctionDef(op->Name()), maybe_unspecified_device_name);

  std::vector<Device*> input_dev_ptrs;
  if (is_multi_device_function) {
    TF_RETURN_IF_ERROR(
        AddInputDevicesToCacheKey(ctx, op, &input_dev_ptrs, &cache_key));
  }

  KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
  if (kernel == nullptr) {
    VLOG(2) << "Creating new kernel for " << op->Name() << " on device "
            << maybe_unspecified_device_name;
    // If we are running a function on an explicitly requested TPU,
    // compile it with XLA.
    // Note that it is not ideal, but currently ok, to set this
    // attribute after computing the kernel cache key above.
    bool compile_with_xla = false;
    if (op->is_function() && device != nullptr &&
        (device->device_type() == "TPU" ||
         device->device_type() == "XLA_GPU" ||
         device->device_type() == "XLA_CPU")) {
      op->MutableAttrs()->Set(kXlaCompileAttr, true);
      compile_with_xla = true;
    }
    bool run_function_with_flr = is_multi_device_function && !compile_with_xla;

    const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
    if (!run_function_with_flr && device == nullptr) {
      status = SelectDevice(ndef, ctx, &device);
      if (!status.ok()) return status;
    }
    const string& device_name =
        device == nullptr ? unspecified_device_name : device->name();
    if (ctx->LogDevicePlacement()) {
      LOG(INFO) << "Executing op " << ndef.op() << " in device " << device_name;
    } else {
      VLOG(1) << "Executing op " << ndef.op() << " in device " << device_name;
    }

    FunctionLibraryRuntime* flr =
        device == nullptr ? nullptr : ctx->func_lib(device);
    if (device != nullptr && flr == nullptr) {
      return errors::Unavailable(
          "Unable to find a FunctionLibraryRuntime corresponding to device ",
          device->name());
    }
    auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
                                                               : ctx->runner();
    GraphCollector* graph_collector = nullptr;
    if (ctx->ShouldStoreGraphs()) {
      graph_collector = ctx->GetGraphCollector();
    }
    // Treat the function as multi-device only when we are not compiling
    // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
    // will create an XlaLaunchOp kernel to compile and run the function.
    if (run_function_with_flr) {
      // Multi-device functions don't use the rendezvous from eager context.
      // If we use that rendezvous, multiple concurrent calls to the same
      // function will likely result in collisions. However, this also means
      // that we don't support legitimate sending/receiving across function
      // boundary.
      VLOG(2) << "Running " << ndef.op() << " using multi-device function. "
              << "compile_with_xla=" << compile_with_xla
              << ". Full node_def=" << ndef.DebugString();
      kernel = new KernelAndDeviceFunc(
          flr, ctx->pflr(), std::move(input_dev_ptrs), runner,
          ctx->GetCollectiveExecutorHandle(), ctx->HostCPU());
    } else {
      VLOG(2) << "Running " << ndef.op() << " using op kernel. "
              << "compile_with_xla=" << compile_with_xla
              << ". Full node_def=" << ndef.DebugString();
      kernel = new KernelAndDeviceOp(
          ctx->GetRendezvous(), ctx->LogMemory(), flr, runner,
          ctx->GetCollectiveExecutorHandle(), ctx->HostCPU());
    }

    status = kernel->Init(ndef, graph_collector);
    if (!status.ok()) {
      delete kernel;
      return status;
    }

    ctx->AddKernelToCache(cache_key, kernel);
  }
  const DataTypeVector& output_dtypes = kernel->output_dtypes();
  const int output_dtypes_size = static_cast<int>(output_dtypes.size());
  if (output_dtypes_size > *num_retvals) {
    return errors::InvalidArgument("Expecting ", output_dtypes.size(),
                                   " outputs, but *num_retvals is ",
                                   *num_retvals);
  }
  *num_retvals = output_dtypes_size;
  const string& device_name = kernel->device() == nullptr
                                  ? unspecified_device_name
                                  : kernel->device()->name();
  status = ValidateInputTypeAndPlacement(
      ctx, device_name, op, kernel,
      ctx->ShouldStoreStepStats() ? ctx->RunMetadataProto() : nullptr);
  if (!status.ok()) return status;
  std::unique_ptr<NodeExecStats> maybe_stats;
  StepStats* maybe_step_stats = nullptr;
  GraphCollector* graph_collector = nullptr;
  if (ctx->ShouldStoreGraphs()) {
    graph_collector = ctx->GetGraphCollector();
  }
  if (ctx->ShouldStoreStepStats()) {
    maybe_step_stats = ctx->RunMetadataProto()->mutable_step_stats();
    int64 now_nanos = Env::Default()->NowNanos();
    maybe_stats.reset(new NodeExecStats);
    maybe_stats->set_node_name(op->Name());
    maybe_stats->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
    maybe_stats->set_all_start_nanos(now_nanos);
    maybe_stats->set_op_start_rel_micros(0);
    maybe_stats->set_op_start_rel_nanos(0);
    maybe_stats->set_scheduled_micros(now_nanos / EnvTime::kMicrosToNanos);
    maybe_stats->set_scheduled_nanos(now_nanos);
    // TODO(apassos) track referenced tensors
  }
  retvals->resize(*num_retvals);
  if (ctx->Async()) {
    // Note that for async mode, execution order will make sure that all
    // input handles are ready before executing them.
    // TODO(agarwal): Consider executing "cheap" kernels inline for
    // performance.
    tensorflow::uint64 id = ctx->NextId();
    for (int i = 0; i < *num_retvals; ++i) {
      (*retvals)[i] = new TensorHandle(
          id, /* d= */ kernel->OutputDevice(i),
          /* op_device= */ kernel->device(),
          /* resource_device= */ kernel->OutputResourceDevice(i),
          output_dtypes[i], ctx);
    }
    EagerNode* node = new ExecuteNode(id, ctx, op->Inputs(), kernel,
                                      maybe_stats.release(), maybe_step_stats,
                                      graph_collector, output_dtypes, *retvals);
    ctx->ExecutorAdd(node);
  } else {
    // Execute checks if retvals[i] is nullptr or not to figure out whether it
    // needs to allocate it.
    status = EagerKernelExecute(ctx, op->Inputs(), kernel, maybe_stats.get(),
                                maybe_step_stats, graph_collector,
                                retvals->data(), *num_retvals);
  }

  return status;
}

#ifndef __ANDROID__
std::function<void()> GetRemoteTensorDestructor(
    EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
    uint64 op_id, int output_num) {
  return [ctx, eager_client, context_id, op_id, output_num]() {
    if (!ctx->HasActiveRemoteContext(context_id)) {
      // This means that this tensor was pointing to a remote device, which
      // has been changed out from under us. Simply return since there is
      // nothing we can do.
      return tensorflow::Status::OK();
    }

    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
    request->set_context_id(context_id);

    auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
    handle_to_decref->set_op_id(op_id);
    handle_to_decref->set_output_num(output_num);

    if (ctx->Async()) {
      tensorflow::uint64 id = ctx->NextId();
      auto* node =
          new eager::RemoteExecuteNode(id, std::move(request), eager_client);
      ctx->ExecutorAdd(node);
    } else {
      eager::EnqueueRequest* actual_request = request.release();
      eager::EnqueueResponse* response = new eager::EnqueueResponse;
      eager_client->EnqueueAsync(
          actual_request, response,
          [actual_request, response](const tensorflow::Status& s) {
            delete actual_request;
            delete response;
          });
    }

    return tensorflow::Status::OK();
  };
}
#endif

// When !ctx->UseSendTensorRPC(), then tensors are shipped between remote
// devices by the receiver invoking the WorkerService.RecvTensor RPC *on the
// sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel).
//
// However, in some configurations the node that has the tensor to be copied
// isn't running a server (WorkerService RPC interface). For such cases,
// this function enables sending tensors using the EagerService.SendTensor RPC
// *on the receiver*.
Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
                             Device* recv_device, TensorHandle** result) {
#ifdef __ANDROID__
  return errors::Unimplemented(
      "Eager's remote execution is not available on Android devices.");
#else
  eager::EagerClient* eager_client;
  uint64 context_id;
  TF_RETURN_IF_ERROR(
      ctx->GetClientAndContextID(recv_device, &eager_client, &context_id));

  eager::SendTensorRequest request;
  eager::SendTensorResponse response;

  request.set_context_id(context_id);
  request.set_op_id(ctx->NextId());
  request.set_device_name(recv_device->name());

  Device* tensor_handle_device = h->device();

  // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
  // copy it to the CPU before copying it out.
  // TODO(nareshmodi): this is currently slow, but can be fixed by making
  // tensor handles aware of more than one device.
  TensorHandle* actual_handle;
  if (tensor_handle_device != nullptr &&
      tensor_handle_device->device_type() != "CPU") {
    TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, ctx->HostCPU(), &actual_handle));
  } else {
    actual_handle = h;
    actual_handle->Ref();
  }

  const Tensor* tensor;
  TF_RETURN_IF_ERROR(actual_handle->Tensor(&tensor));
  tensor->AsProtoTensorContent(request.add_tensors());

  const tensorflow::uint64 id = request.op_id();

  // TODO(nareshmodi): support making this call async.
  Notification n;
  Status status;
  eager_client->SendTensorAsync(&request, &response,
                                [&n, &status](const Status& s) {
                                  status = s;
                                  n.Notify();
                                });
  n.WaitForNotification();
  if (!status.ok()) return status;

  std::function<void()> destructor =
      GetRemoteTensorDestructor(ctx, eager_client, context_id, id, 0);

  *result = new TensorHandle(id, /*output_num=*/0, /*remote_shape_node_id=*/0,
                             tensor->dtype(), std::move(destructor),
                             /*d=*/recv_device, /*op_device=*/recv_device,
                             /*resource_device=*/nullptr, ctx);
  (*result)->SetRemoteShape(MakeUnique<TensorShape>(tensor->shape()));

  actual_handle->Unref();

  return Status::OK();
#endif
}

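// Executes `op` on a remote worker by enqueueing it through the eager
// service. Inputs living on a different task are first copied to the target
// task's CPU, and the returned handles are remote TensorHandles whose shapes
// are filled in once the remote response arrives.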
Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
                          int* num_retvals) {
#ifdef __ANDROID__
  return errors::Unimplemented(
      "Eager's remote execution is not available on Android devices.");
#else
  EagerContext* ctx = op->EagerContext();

  eager::EagerClient* eager_client;
  uint64 context_id;
  TF_RETURN_IF_ERROR(
      ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));

  std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
  eager::EnqueueResponse response;

  request->set_context_id(context_id);

  auto* remote_op = request->add_queue()->mutable_operation();

  for (int i = 0; i < op->Inputs().size(); i++) {
    tensorflow::Device* input_device = op->Inputs()[i]->device();
    if (op->Device() != input_device &&
        // If the expected and actual devices are on the same task, don't
        // explicitly copy, and instead depend on the copy to happen locally
        // when the op is executed on the device.
        !OnSameTask(ctx, op->Device(), input_device)) {
      tensorflow::Device* remote_cpu_device;
      TF_RETURN_IF_ERROR(
          CPUDeviceOnTask(ctx, op->Device(), &remote_cpu_device));
      // TODO(b/110044833): It's possible the same tensor gets copied to the
      // remote device repeatedly.
      // Always copy to the remote CPU so that the actual device can be
      // correctly determined after the kernel is selected/instantiated, since
      // the op might have its inputs on host memory.
      TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
          op, op->Device()->name(), i, remote_cpu_device,
          /* run_metadata= */ nullptr, &(*op->MutableInputs())[i]));
    }

    tensorflow::TensorHandle* input = op->Inputs()[i];

    tensorflow::int64 op_id;
    int32 output_num;
    TF_RETURN_IF_ERROR(input->RemoteAddress(&op_id, &output_num));

    auto* remote_op_input = remote_op->add_inputs();
    remote_op_input->set_op_id(op_id);
    remote_op_input->set_output_num(output_num);
  }

  remote_op->set_id(op->EagerContext()->NextId());
  remote_op->set_name(op->Name());
  // Inputs set above.
  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
  remote_op->set_device(op->Device()->name());

  DataTypeVector output_dtypes;
  TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));

  if (*num_retvals != output_dtypes.size()) {
    return errors::InvalidArgument(
        "num_retvals does not match expected output dtypes");
  }

  tensorflow::Device* op_device = op->Device();

  bool is_async = op->EagerContext()->Async();
  uint64 remote_node_id = 0;

  if (is_async) {
    remote_node_id = op->EagerContext()->NextId();
  }

  const tensorflow::uint64 id = remote_op->id();
  for (int i = 0; i < *num_retvals; i++) {
    // TODO(nareshmodi): Change the callback to instead add the decref to a
    // list of pending decrefs that we can send as a batch with the next
    // execute.
    std::function<void()> destructor =
        GetRemoteTensorDestructor(ctx, eager_client, context_id, id, i);

    // The device_ and resource_device_ of this TensorHandle are not correct.
    // It is pretty hard to make it correct because for multi-device functions,
    // we don't know the output device until the function is instantiated.
    // Luckily, we don't need to know the correct remote device here. We just
    // need to know that it is remote. If we need to copy this tensor to this
    // process, the remote end will know the correct device of this handle.
    retvals[i] = new TensorHandle(
        remote_op->id(), i, remote_node_id, output_dtypes[i],
        std::move(destructor),
        /*d=*/op_device, /*op_device=*/op_device,
        /*resource_device=*/output_dtypes[i] == DT_RESOURCE ? op_device
                                                            : nullptr,
        op->EagerContext());
  }

  if (is_async) {
    // Copy the output handles, since the container for them might get
    // destroyed.
    gtl::InlinedVector<TensorHandle*, 2> retvals_copy;
    for (int i = 0; i < *num_retvals; i++) {
      retvals_copy.push_back(retvals[i]);
      retvals_copy[i]->Ref();
    }
    // Unable to capture via std::move, so bind instead.
    auto* node = new eager::RemoteExecuteNode(
        remote_node_id, std::move(request), eager_client, op->Inputs(),
        std::bind(
            [](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
               const Status& status, const eager::EnqueueResponse& response) {
              if (!status.ok()) return;
              for (int i = 0; i < retvals.size(); i++) {
                retvals[i]->SetRemoteShape(MakeUnique<TensorShape>(
                    response.queue_response(0).shape(i)));
                retvals[i]->Unref();
              }
            },
            std::move(retvals_copy), std::placeholders::_1,
            std::placeholders::_2));
    op->EagerContext()->ExecutorAdd(node);
  } else {
    Notification n;
    Status status;
    eager_client->EnqueueAsync(request.get(), &response,
                               [&n, &status](const Status& s) {
                                 status = s;
                                 n.Notify();
                               });
    n.WaitForNotification();

    if (!status.ok()) return status;

    for (int i = 0; i < *num_retvals; i++) {
      retvals[i]->SetRemoteShape(
          MakeUnique<TensorShape>(response.queue_response(0).shape(i)));
    }
  }

  return Status::OK();
#endif
}

// These ops are not pinnable since they generate data. It can be slower to
// generate and then copy the data instead of just generating the data on the
// device directly.
bool IsPinnableOp(const string& op_type) {
  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
      "RandomUniform",
      "RandomUniformInt",
      "RandomStandardNormal",
      "StatelessRandomUniform",
      "StatelessRandomUniformInt",
      "StatelessRandomNormal",
  });

  // XRT ops refer to per-device handles that are not safe to move between
  // devices.
  return unpinnable_ops->find(op_type) == unpinnable_ops->end() &&
         !absl::StartsWith(op_type, "XRT");
}

// The Op device may be updated if:
// - A resource touching input is specified: all resource-touching ops run in
//   the device the resource is, regardless of anything else that has been
//   specified. This is identical to the graph mode behavior.
//
// - All op inputs are on the CPU, small (<64 elements) and integers
//   (int32/int64). This can be disabled by setting the environment variable
//   "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
Status MaybeUpdateOpDevice(EagerOperation* op) {
  if (op->is_function()) {
    // Don't update the device of direct function calls.
    // Particularly, if the user did not explicitly request any device for this
    // function, picking a device would result in this device being the default
    // for nodes inside the function. This is undesirable for multi-device
    // functions since the not-explicitly-placed nodes inside the body will all
    // end up on this default device.
    return Status::OK();
  }
  EagerContext* ctx = op->EagerContext();
  bool all_inputs_eligible_for_cpu_pinning =
      ctx->PinSmallOpsToCPU() && !op->is_function() && IsPinnableOp(op->Name());
  Device* op_device = op->Device() == nullptr ? ctx->HostCPU() : op->Device();
  for (int i = 0; i < op->Inputs().size(); ++i) {
    TensorHandle* tensor_handle = op->Inputs()[i];
    if (tensor_handle->dtype == DT_RESOURCE) {
      Device* resource_device = tensor_handle->resource_device();
      VLOG(2) << "for op " << op->Name() << " input " << i << " "
              << DataTypeString(tensor_handle->dtype)
              << " input device = " << resource_device->name()
              << ", op device = " << op_device->name();
      // We check for `op->Device() == nullptr` because it can be later
      // interpreted as unspecified device and a different device can
      // be selected based on device priority. If any input to an op
      // is a resource we must pin it to prevent different device selection.
      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
      if (resource_device != op_device || op->Device() == nullptr) {
        VLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
                << "device of operation " << op->Name() << " to "
                << resource_device->name() << " because input #" << i
                << " is a resource in this device.";
        op->SetDevice(resource_device);
      }
      all_inputs_eligible_for_cpu_pinning = false;
      // No point in looking at other inputs. If there are other resources,
      // they must have the same device and we already declared the op to be
      // ineligible for CPU pinning.
      break;
    } else if (all_inputs_eligible_for_cpu_pinning) {
      Device* input_device = tensor_handle->device();
      input_device = input_device == nullptr ? ctx->HostCPU() : input_device;
      VLOG(2) << "for op " << op->Name() << " input " << i << " "
              << DataTypeString(tensor_handle->dtype)
              << " input device = " << input_device->name()
              << ", op device = " << op_device->name();

      // To remain eligible for CPU pinning, the input must be on the host CPU.
      if (input_device != ctx->HostCPU()) {
        all_inputs_eligible_for_cpu_pinning = false;
        continue;
      }

      if (tensor_handle->dtype != DataType::DT_INT32 &&
          tensor_handle->dtype != DataType::DT_INT64) {
        all_inputs_eligible_for_cpu_pinning = false;
        continue;
      }

      int64 num_elements;
      TF_RETURN_IF_ERROR(tensor_handle->NumElements(&num_elements));
      if (num_elements > 64) {
        all_inputs_eligible_for_cpu_pinning = false;
      }
    }
  }

  // Ops without inputs are usually ops that generate a tensor in some way and
  // usually require being present on whatever device they are scheduled on
  // (e.g. VarHandleOp or _Recv).
  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
  // an op, but there is a GPU kernel?
  if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
    VLOG(1) << "Forcing op " << op->Name()
            << " to be on the CPU since all input tensors have an "
               "int32/int64 dtype, and are small (less than 64 elements).";
    op->SetDevice(ctx->HostCPU());
  }

  return Status::OK();
}
}  // namespace

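// Entry point for executing an eager op: possibly re-pins the op's device
// (resource inputs, small-tensor CPU pinning), then dispatches to local or
// remote execution depending on where the op's device lives.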
Status EagerExecute(EagerOperation* op,
                    gtl::InlinedVector<TensorHandle*, 2>* retvals,
                    int* num_retvals) {
  TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));

  bool op_is_local = IsLocal(op->EagerContext(), op->Device());

  if (op_is_local) {
    return EagerLocalExecute(op, retvals, num_retvals);
  }

  if (op->EagerContext()->LogDevicePlacement()) {
    LOG(INFO) << "Executing op " << op->Name() << " in device "
              << op->Device()->name();
  }

  return EagerRemoteExecute(op, retvals->data(), num_retvals);
}

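// Runs an already-instantiated kernel on its device, optionally recording
// step stats and collected graphs into the context's RunMetadata, and fills
// in `retvals` (allocating TensorHandles for entries that are null).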
Status EagerKernelExecute(EagerContext* ctx,
                          const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
                          KernelAndDevice* kernel, NodeExecStats* maybe_stats,
                          StepStats* maybe_step_stats,
                          GraphCollector* graph_collector,
                          TensorHandle** retvals, int num_retvals) {
  std::vector<Tensor> outputs(1);

  // If there are multiple references to a TensorHandle in 'op_inputs' we must
  // increment the reference count of the corresponding Tensor or risk it being
  // overwritten during kernel execution. The reference count is incremented
  // below when we insert a copy of the Tensor into protected_tensors, and will
  // be decremented once execution is complete.
  std::vector<tensorflow::Tensor> protected_tensors;
  for (int i = 0; i < op_inputs.size(); ++i) {
    if (!op_inputs[i]->RefCountIsOne()) {
      const Tensor* input_tensor = nullptr;
      TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
      protected_tensors.push_back(*input_tensor);
    }
  }

  gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
  for (int i = 0; i < op_inputs.size(); ++i) {
    TF_RETURN_IF_ERROR(op_inputs[i]->TensorValue(&input_vector[i]));
  }

  // TODO(apassos) figure out how to record stats for ops which are a part of
  // functions.
  // TODO(agarwal): change Run to take vector of handles ?
  ScopedStepContainer* container = ctx->StepContainer();
  if (container == nullptr) {
    TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, maybe_stats,
                                   maybe_step_stats, graph_collector));
  } else {
    TF_RETURN_IF_ERROR(kernel->Run(container, input_vector, &outputs,
                                   maybe_stats, maybe_step_stats,
                                   graph_collector));
  }
  if (graph_collector != nullptr) {
    mutex_lock ml(*ctx->MetadataMu());
    {
      GraphCollector* collector = ctx->GetGraphCollector();
      mutex_lock mll(collector->mu);

      // Adding to partition graphs for backward compatibility.
      for (const auto& graph : collector->partitioned_graphs) {
        *ctx->RunMetadataProto()->add_partition_graphs() = graph;
      }

      if (collector->dirty) {
        auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
        *function_graphs->mutable_post_optimization_graph() =
            collector->optimized_graph;
        *function_graphs->mutable_pre_optimization_graph() =
            collector->raw_graph;
        for (const auto& graph : collector->partitioned_graphs) {
          *function_graphs->add_partition_graphs() = graph;
        }
      }

      collector->ClearGraphs();
    }
  }
  if (maybe_stats != nullptr) {
    int64 nanos = Env::Default()->NowNanos();
    maybe_stats->set_op_end_rel_micros(nanos / EnvTime::kMicrosToNanos -
                                       maybe_stats->all_start_micros());
    maybe_stats->set_op_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
    maybe_stats->set_all_end_rel_micros(nanos / EnvTime::kMicrosToNanos -
                                        maybe_stats->all_start_micros());
    maybe_stats->set_all_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
    if (ctx->ShouldStoreStepStats()) {
      mutex_lock ml(*ctx->MetadataMu());
      {
        auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
        // Lazily initialize the RunMetadata with information about all devices
        // if this is the first call.
        while (step_stats->dev_stats_size() < ctx->devices()->size()) {
          step_stats->add_dev_stats();
        }
        // Find the current device's index.
        // If device is a nullptr (we are running a function without explicitly
        // requested device), attribute the function runtime to CPU.
        Device* attribution_device = kernel->device();
        if (attribution_device == nullptr) {
          attribution_device = ctx->HostCPU();
        }
        int device_idx = 0;
        for (int i = 0; i < ctx->devices()->size(); ++i) {
          if (ctx->devices()->at(i) == attribution_device) {
            device_idx = i;
            break;
          }
        }
        // Populate the device stats for this device.
        auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
        dev_stats->set_device(attribution_device->name());
        *dev_stats->add_node_stats() = *maybe_stats;
      }
    }
  }
  DCHECK_EQ(num_retvals, outputs.size());
  for (int i = 0; i < num_retvals; ++i) {
    if (retvals[i] == nullptr) {
      retvals[i] =
          new TensorHandle(outputs[i], /* d= */ kernel->OutputDevice(i),
                           /* op_device= */ kernel->device(), ctx);
    } else {
      // In the async case, the retval is not a nullptr, and its device is
      // already set since all TensorHandles always have their device set
      // (potentially to nullptr) during construction.
      DCHECK_EQ(kernel->device(), retvals[i]->op_device());
      DCHECK_EQ(kernel->OutputDevice(i), retvals[i]->device());

      retvals[i]->SetTensor(outputs[i]);
    }
  }
  return Status::OK();
}

namespace {

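// Copies `h` to `dstd` within this process. In async mode the copy is
// enqueued as a CopyToDeviceNode and a handle that becomes ready later is
// returned; otherwise the copy happens synchronously.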
Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
                              TensorHandle** result) {
  TF_RETURN_IF_ERROR(ctx->GetStatus());
  if (ctx->Async()) {
    // Note that `h` may not be currently ready. However execution order will
    // make sure that `h` is ready before the copy is actually done.
    CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
    TensorHandle* output = node->dst();
    // Note that calling Add makes `node` accessible by the EagerExecutor
    // thread. So further accesses need to be thread-safe.
    ctx->ExecutorAdd(node);
    *result = output;
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, dstd, result));
    return Status::OK();
  }
}

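// Builds and executes a _Send op that ships `h` from `device` to
// `recv_device`, using `wire_id` as the rendezvous tensor name.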
Status ExecuteSend(EagerContext* ctx, tensorflow::Device* device,
                   TensorHandle* h, StringPiece wire_id,
                   const string& recv_device) {
  const tensorflow::AttrTypeMap* types;
  bool is_function = false;
  TF_RETURN_IF_ERROR(
      tensorflow::AttrTypeMapForOp("_Send", &types, &is_function));
  DCHECK(!is_function);
  tensorflow::EagerOperation op(ctx, "_Send", /*is_function=*/false, types);

  op.AddInput(h);

  op.SetDevice(device);

  op.MutableAttrs()->Set("tensor_name", wire_id);
  op.MutableAttrs()->Set("send_device", device->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64>(device->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device);
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("T", h->dtype);

  int num_outputs = 0;
  gtl::InlinedVector<TensorHandle*, 2> retvals;

  return EagerExecute(&op, &retvals, &num_outputs);
}

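// Builds and executes the matching _Recv op on `device` and returns the
// received tensor handle in `*result`.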
Status ExecuteRecv(EagerContext* ctx, tensorflow::Device* device,
                   DataType dtype, StringPiece wire_id,
                   const string& send_device, int64 send_device_incarnation,
                   TensorHandle** result) {
  const tensorflow::AttrTypeMap* types;
  bool is_function = false;
  TF_RETURN_IF_ERROR(
      tensorflow::AttrTypeMapForOp("_Recv", &types, &is_function));
  DCHECK(!is_function);
  tensorflow::EagerOperation op(ctx, "_Recv", /*is_function=*/false, types);

  op.SetDevice(device);

  op.MutableAttrs()->Set("tensor_name", wire_id);
  op.MutableAttrs()->Set("send_device", send_device);
  op.MutableAttrs()->Set("send_device_incarnation", send_device_incarnation);
  op.MutableAttrs()->Set("recv_device", device->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("tensor_type", dtype);

  int num_outputs = 1;
  gtl::InlinedVector<TensorHandle*, 2> retvals(num_outputs);

  TF_RETURN_IF_ERROR(EagerExecute(&op, &retvals, &num_outputs));

  *result = retvals.at(0);

  return Status::OK();
}

// This gets a unique wire ID. We add a random identifier so that if the
// worker has other clients that it is servicing, we don't have any collision.
string GetUniqueWireID() {
  static tensorflow::uint64 random_seed = random::New64();
  static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
  static tensorflow::int64 wireid GUARDED_BY(wireid_mutex) = 0;
  tensorflow::mutex_lock l(wireid_mutex);
  return strings::StrCat(random_seed, "_", wireid++);
}

}  // namespace

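// Copies `h` to the device named `device_name`. Local-to-local copies go
// through LocalEagerCopyToDevice; when the sender is local, the receiver is
// remote, and SendTensor RPCs are enabled, the tensor is pushed via
// EagerRemoteSendTensor; otherwise a _Send/_Recv pair keyed by a unique wire
// id is executed.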
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                         const char* device_name, TensorHandle** result) {
  tensorflow::Device* send_device = h->device();

  if (send_device == nullptr) {
    send_device = ctx->HostCPU();
  }

  bool sender_is_local = IsLocal(ctx, send_device);

  tensorflow::Device* recv_device;
  TF_RETURN_IF_ERROR(FindDeviceFromName(ctx, device_name, &recv_device));

  bool recver_is_local = IsLocal(ctx, recv_device);

  if (sender_is_local && recver_is_local) {
    return LocalEagerCopyToDevice(h, ctx, recv_device, result);
  } else if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
    return EagerRemoteSendTensor(ctx, h, recv_device, result);
  } else {
    string wire_id = GetUniqueWireID();

    TF_RETURN_IF_ERROR(
        ExecuteSend(ctx, send_device, h, wire_id, recv_device->name()));

    return ExecuteRecv(ctx, recv_device, h->dtype, wire_id, send_device->name(),
                       send_device->attributes().incarnation(), result);
  }
}
}  // namespace tensorflow