/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"

#include <memory>

#include "absl/strings/match.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {

Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index,
                                    Tensor* val) const {
  if (index.sub_index >= 0) {
    return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index,
                                   " for argument ", index.index);
  }
  Tensor* arg = tensor_args_.at(index.index).tensor;
  if (arg) {
    *val = *arg;
    return OkStatus();
  } else {
    return errors::NotFound("Argument ", index.index, " has no local tensor.");
  }
}

std::vector<Tensor> EagerKernelArgs::GetLocalTensors() const {
  std::vector<Tensor> local_inputs;
  local_inputs.reserve(tensor_args_.size());
  for (const TensorValue& tensor_value : tensor_args_) {
    local_inputs.push_back(*tensor_value.tensor);
  }
  return local_inputs;
}

std::function<void(std::function<void()>)>* KernelAndDevice::get_runner()
    const {
  if (runner_) {
    return runner_;
  } else {
    static auto* default_runner =
        new std::function<void(std::function<void()>)>(
            [](const std::function<void()>& f) { f(); });
    return default_runner;
  }
}
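
// Illustrative usage sketch (added commentary, not part of the original
// source): callers such as KernelAndDeviceOp::Run below pass the returned
// pointer as params.runner, and the executor invokes it to schedule closures,
// roughly as:
//
//   std::function<void(std::function<void()>)>* runner = get_runner();
//   (*runner)([] { /* a unit of kernel work */ });
//
// When no runner_ was injected, the static default runner above simply runs
// the closure inline on the calling thread.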

KernelAndDeviceFunc::~KernelAndDeviceFunc() {
  if (handle_ != kInvalidHandle) {
    Status status = pflr_->ReleaseHandle(handle_);
    if (!status.ok()) {
      LOG(INFO) << "Ignoring error status when releasing multi-device function "
                   "handle "
                << status.ToString();
    }
  }
}

Status KernelAndDeviceOp::Init(const bool log_device_placement,
                               const NodeDef& ndef,
                               GraphCollector* graph_collector) {
  OpKernel* k = nullptr;
  if (flr_ == nullptr) {
    return errors::Internal(
        "A valid FunctionLibraryRuntime must be provided when running ops "
        "based on OpKernel.");
  }
  std::shared_ptr<const NodeProperties> props;
  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
      ndef, flr_->GetFunctionLibraryDefinition(), &props));
  TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k));
  kernel_.reset(k);
  const auto* op_reg_data = OpRegistry::Global()->LookUp(ndef.op());
  if (op_reg_data != nullptr) {
    is_distributed_communication_op_ =
        op_reg_data->op_def.is_distributed_communication();
  }

  input_alloc_attrs_.resize(kernel_->num_inputs());
  input_devices_.resize(kernel_->num_inputs(), device_);
  for (size_t i = 0; i < input_alloc_attrs_.size(); ++i) {
    bool host = kernel_->input_memory_types()[i] == tensorflow::HOST_MEMORY;
    input_alloc_attrs_[i].set_on_host(host);
    if (host && input_devices_[i]->device_type() != DEVICE_CPU) {
      input_devices_[i] = host_cpu_device_;
    }
  }
  output_alloc_attrs_.resize(kernel_->num_outputs());
  for (size_t i = 0; i < output_alloc_attrs_.size(); ++i) {
    output_alloc_attrs_[i].set_on_host(kernel_->output_memory_types()[i] ==
                                       tensorflow::HOST_MEMORY);
  }

  return OkStatus();
}
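
// Note (added commentary, not from the original source): the input loop in
// Init() above means that for a non-CPU kernel whose i-th input is declared
// HOST_MEMORY (e.g. a shape or index argument), input_alloc_attrs_[i] is
// marked on_host and input_devices_[i] is rewritten to host_cpu_device_, so
// InputDevice(i) below reports the host CPU device for that argument.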

Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement,
                                            const NodeDef& ndef,
                                            GraphCollector* graph_collector) {
  const OpDef* op_def = nullptr;
  const FunctionDef* function_def;
  if (flr_ == nullptr) {
    // If the function is being executed without an explicit device request,
    // look up the FunctionDef in the CPU's FLR. All FLRs share the same
    // library.
    function_def = pflr_->GetFLR(host_cpu_device_->name())
                       ->GetFunctionLibraryDefinition()
                       ->Find(ndef.op());
  } else {
    function_def = flr_->GetFunctionLibraryDefinition()->Find(ndef.op());
  }

  if (function_def != nullptr) {
    op_def = &(function_def->signature());
  } else {
    TF_RETURN_IF_ERROR(OpDefForOp(ndef.op(), &op_def));
  }
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(ndef, *op_def, &input_dtypes_, &output_dtypes_));

  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = device_ == nullptr ? "" : device_->name();
  options.is_multi_device_function = true;
  for (const Device* device : input_devices_) {
    options.input_devices.push_back(device->name());
  }
  options.composite_devices = composite_devices_;
  options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
  if (outputs_on_op_device_) {
    const FunctionLibraryDefinition* lib_def =
        pflr_->GetFunctionLibraryDefinition();
    const FunctionDef* fdef = lib_def->Find(ndef.op());
    if (fdef == nullptr) {
      return errors::InvalidArgument("Failed to find function ", ndef.op());
    }
    for (int i = 0; i < fdef->signature().output_arg_size(); ++i) {
      options.output_devices.push_back(options.target);
    }
  }

  const auto& it = ndef.attr().find("executor_type");
  if (it != ndef.attr().end()) {
    options.executor_type = it->second.s();
  }
  const auto& is_component_fn_it = ndef.attr().find("is_component_function");
  if (is_component_fn_it != ndef.attr().end()) {
    options.is_component_function = is_component_fn_it->second.b();
  }
#if !defined(IS_MOBILE_PLATFORM)
  // The Android TensorFlow library does not include Grappler.
  const auto& config_it = ndef.attr().find("config_proto");
  if (config_it != ndef.attr().end()) {
    if (!options.config_proto.ParseFromString(config_it->second.s())) {
      return errors::InvalidArgument(
          "Failed to parse config_proto attribute as tensorflow::ConfigProto "
          "proto.");
    }
    grappler::GrapplerItem::OptimizationOptions optimization_options =
        grappler::CreateOptOptionsForEager();

    options.optimize_graph_fn = std::bind(
        grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
        std::placeholders::_3, std::placeholders::_4, std::placeholders::_5,
        options.config_proto, function_def->signature().name(),
        optimization_options, std::placeholders::_6);
  }
#endif  // !IS_MOBILE_PLATFORM
  options.graph_collector = graph_collector;

  options.allow_small_function_optimizations =
      allow_small_function_optimizations_;

  options.allow_control_flow_sync_execution =
      allow_control_flow_sync_execution_;

  options.shape_inference_on_tfe_dialect_import =
      shape_inference_on_tfe_dialect_import_;

  // In eager mode we always inline all functions into the top-level function
  // body graph, to get a single executable graph that can be optimized across
  // function boundaries (e.g. pruning unused inputs and outputs in a function
  // call chain). This is required to mimic graph-mode execution, with its
  // aggressive pruning of nodes that are not in the transitive fanin of the
  // fetches.
  options.config_proto.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);

  options.config_proto.set_log_device_placement(log_device_placement);

  options.int_args_and_retvals_on_device = int_args_and_retvals_on_device_;

  if (xla_compile_device_type_.has_value()) {
    options.xla_compile_device_type = xla_compile_device_type_.value();
  }

  TF_RETURN_IF_ERROR(
      pflr_->Instantiate(ndef.op(), AttrSlice(ndef), options, &handle_));
  return pflr_->IsCrossProcess(handle_, &is_cross_process_);
}
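
// Illustrative sketch (attribute values are assumptions, not from this file):
// a function-call NodeDef handled by InstantiateFunc above might carry the
// optional attributes read there, e.g. in textproto form:
//
//   name: "my_fn_call"  # hypothetical node name
//   op: "my_fn"         # name of a FunctionDef in the function library
//   attr { key: "executor_type"         value { s: "" } }
//   attr { key: "is_component_function" value { b: false } }
//   attr { key: "config_proto"          value { s: "<serialized ConfigProto>" } }
//
// Attributes that are absent simply leave the corresponding
// FunctionLibraryRuntime::InstantiateOptions fields at their defaults.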

Status KernelAndDeviceFunc::Init(const bool log_device_placement,
                                 const NodeDef& ndef,
                                 GraphCollector* graph_collector) {
  TF_RETURN_IF_ERROR(
      InstantiateFunc(log_device_placement, ndef, graph_collector));
  return pflr_->GetOutputDevices(handle_, &output_devices_);
}

namespace {
// In certain contexts (e.g. TPU async executions), the CancellationManager is
// used to shut down the device in error scenarios (as opposed to using the
// AsyncCompute's DoneCallback). This is handled through the
// {inc,dec}_num_deferred_ops_function.
struct OpExecutionState : public core::RefCounted {
  // TODO(nareshmodi): consider refcounting the cancellation_manager.
  CancellationManager cancellation_manager;
};
}  // anonymous namespace
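
// Illustrative lifecycle sketch (added commentary, derived from
// KernelAndDeviceOp::Run below): an OpExecutionState created for a deferred
// kernel is kept alive purely by refcounting, roughly:
//
//   auto* state = new OpExecutionState;      // refcount == 1
//   params.inc_num_deferred_ops_function();  // Ref():   +1 per deferred op
//   ...                                      // deferred op completes
//   params.dec_num_deferred_ops_function();  // Unref(): -1 per deferred op
//   state->Unref();                          // Run() drops its own reference
//
// so the embedded cancellation_manager outlives Run() for as long as any
// deferred op might still need to observe a cancellation.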

Status KernelAndDeviceOp::Run(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace,
    CoordinationServiceAgent* coordination_service_agent) {
  OpKernelContext::Params params;
  params.device = device_;
  params.frame_iter = FrameAndIter(0, 0);
  params.inputs = *inputs.GetTensorValues();
  params.op_kernel = kernel_.get();
  params.resource_manager = device_->resource_manager();
  params.input_alloc_attrs = input_alloc_attrs_;
  params.output_attr_array = output_alloc_attrs_.data();
  params.function_library = flr_;
  params.slice_reader_cache = &slice_reader_cache_;
  params.rendezvous = rendezvous_;
  params.stack_trace = stack_trace;
  OpExecutionState* op_execution_state = nullptr;

  CancellationManager default_cancellation_manager;
  if (cancellation_manager) {
    params.cancellation_manager = cancellation_manager;
  } else if (kernel_->is_deferred()) {
    op_execution_state = new OpExecutionState;
    params.cancellation_manager = &op_execution_state->cancellation_manager;
    params.inc_num_deferred_ops_function = [op_execution_state]() {
      op_execution_state->Ref();
    };
    params.dec_num_deferred_ops_function = [op_execution_state]() {
      op_execution_state->Unref();
    };
  } else {
    params.cancellation_manager = &default_cancellation_manager;
  }

  params.log_memory = log_memory_;

  params.runner = get_runner();

  params.step_container = step_container;

  params.collective_executor =
      collective_executor_ ? collective_executor_->get() : nullptr;

  params.coordination_service_agent = coordination_service_agent;

  OpKernelContext context(&params);

  {
    port::ScopedFlushDenormal flush;
    port::ScopedSetRound round(FE_TONEAREST);
    // 'AnnotatedTraceMe' traces both the scheduling time on the host and the
    // execution time on the device of the OpKernel.
    profiler::AnnotatedTraceMe activity(
        [&] { return kernel_->TraceString(context, /*verbose=*/false); },
        profiler::TraceMeLevel::kInfo);
    device_->Compute(kernel_.get(), &context);
  }

  // Drop this function's reference to op_execution_state; it is deleted once
  // no deferred ops are still running.
  if (op_execution_state != nullptr) {
    op_execution_state->Unref();
  }

  Status s = context.status();
  if (TF_PREDICT_FALSE(!s.ok())) {
    if (errors::IsUnavailable(s) && !is_distributed_communication_op_) {
      s = errors::ReplaceErrorFromNonCommunicationOps(s, kernel_->name());
    }
    return s;
  }

  if (outputs != nullptr) {
    outputs->clear();
    for (int i = 0; i < context.num_outputs(); ++i) {
      const auto* output_tensor = context.mutable_output(i);
      if (output_tensor != nullptr) {
        outputs->push_back(Tensor(*output_tensor));
      } else {
        outputs->push_back(Tensor());
      }
    }
  }
  return OkStatus();
}

std::shared_ptr<FunctionLibraryRuntime::Options>
KernelAndDeviceFunc::PrepareForRun(
    ScopedStepContainer* step_container, std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace,
    CoordinationServiceAgent* coordination_service_agent) {
  std::shared_ptr<FunctionLibraryRuntime::Options> opts = nullptr;
  if (eager_func_params.has_value()) {
    const EagerFunctionParams& params = eager_func_params.value();
    if (params.step_id.has_value()) {
      // If the function is a remote component of a cross-process function,
      // reuse the parent function's step id.
      opts = std::make_shared<FunctionLibraryRuntime::Options>(
          params.step_id.value());
    } else {
      opts = std::make_shared<FunctionLibraryRuntime::Options>();
    }
    // Reuse the op id if it exists.
    if (params.op_id != kInvalidOpId) {
      opts->op_id = params.op_id;
    }
  } else {
    opts = std::make_shared<FunctionLibraryRuntime::Options>();
    if (get_op_id_ && is_cross_process_) {
      // If the function is a cross-process function and the remote execution
      // goes through the eager service, create an eager op id for the
      // function.
      opts->op_id = get_op_id_();
    }
  }

  // We don't pass the rendezvous from the eager context because we can get
  // tensor name collisions in send/recv ops when running multiple instances
  // of the same multi-device function concurrently.
  Rendezvous* rendezvous = rendezvous_creator_(opts->step_id);
  opts->rendezvous = rendezvous;
  opts->create_rendezvous = false;

  // Create a cancellation manager to be used by the FLR options if the caller
  // does not pass one in. If the caller does provide one, pass it to the
  // process FLR and leave the locally created one unused.
  std::shared_ptr<CancellationManager> local_cm;
  if (cancellation_manager) {
    opts->cancellation_manager = cancellation_manager;
  } else {
    opts->cancellation_manager = new CancellationManager;
  }
  opts->allow_dead_tensors = true;
  opts->step_container = step_container;
  opts->collective_executor =
      collective_executor_ ? collective_executor_->get() : nullptr;
  opts->stack_trace = stack_trace;

  opts->stats_collector = nullptr;
  opts->runner = get_runner();
  opts->coordination_service_agent = coordination_service_agent;

  outputs->clear();
  return opts;
}
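
// Ownership note (added commentary, derived from the Run/RunAsync call sites
// below): the returned options carry two resources that the caller releases
// once the function finishes:
//
//   if (cancellation_manager == nullptr) {
//     delete opts->cancellation_manager;  // only if PrepareForRun created it
//   }
//   static_cast<Rendezvous*>(opts->rendezvous)->Unref();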

Status KernelAndDeviceFunc::Run(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace,
    CoordinationServiceAgent* coordination_service_agent) {
  profiler::TraceMe activity("KernelAndDeviceFunc::Run",
                             profiler::TraceMeLevel::kInfo);
  // Don't try to handle packed or remote inputs synchronously.
  if (inputs.HasRemoteOrPackedInputs() || eager_func_params.has_value()) {
    Notification n;
    Status status;
    RunAsync(step_container, inputs, outputs, cancellation_manager,
             eager_func_params, coordination_service_agent,
             [&status, &n](Status s) {
               status = s;
               n.Notify();
             });
    n.WaitForNotification();
    return status;
  }
  std::shared_ptr<FunctionLibraryRuntime::Options> opts =
      PrepareForRun(step_container, outputs, cancellation_manager,
                    eager_func_params, stack_trace, coordination_service_agent);

  std::vector<Tensor> rets;
  Status s;
  {
    port::ScopedFlushDenormal flush;
    port::ScopedSetRound round(FE_TONEAREST);
    s.Update(pflr_->RunSync(*opts, handle_, inputs.GetLocalTensors(), &rets));
  }

  if (cancellation_manager == nullptr) {
    delete opts->cancellation_manager;
  }
  static_cast<Rendezvous*>(opts->rendezvous)->Unref();
  outputs->reserve(rets.size());
  for (auto& v : rets) {
    outputs->push_back(std::move(v));
  }
  return s;
}

void KernelAndDeviceFunc::RunAsync(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    CoordinationServiceAgent* coordination_service_agent,
    std::function<void(const Status&)> done) {
  profiler::TraceMe activity("KernelAndDeviceFunc::RunAsync",
                             profiler::TraceMeLevel::kInfo);
  std::shared_ptr<FunctionLibraryRuntime::Options> opts = PrepareForRun(
      step_container, outputs, cancellation_manager, eager_func_params,
      absl::nullopt, coordination_service_agent);

  pflr_->Run(
      *opts, handle_, inputs, outputs,
      [opts, cancellation_manager, done = std::move(done)](const Status& s) {
        if (cancellation_manager == nullptr) {
          delete opts->cancellation_manager;
        }
        static_cast<Rendezvous*>(opts->rendezvous)->Unref();
        done(s);
      });
}

tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const {
  if (kernel_->output_memory_types()[idx] == HOST_MEMORY) {
    return nullptr;
  }
  return device_;
}

tensorflow::Device* KernelAndDeviceFunc::OutputDevice(int idx) const {
  if (output_dtypes_[idx] == DT_RESOURCE) {
    return nullptr;
  }
  return output_devices_[idx];
}

tensorflow::Device* KernelAndDeviceOp::OutputResourceDevice(int idx) const {
  if (kernel_->output_type(idx) == DT_RESOURCE) {
    return device_;
  }
  return nullptr;
}

tensorflow::Device* KernelAndDeviceFunc::OutputResourceDevice(int idx) const {
  if (output_dtypes_[idx] == DT_RESOURCE) {
    return output_devices_[idx];
  }
  return nullptr;
}

Device* KernelAndDeviceOp::InputDevice(int i) const {
  return input_devices_[i];
}

Device* KernelAndDeviceFunc::InputDevice(int i) const {
  if ((input_dtypes_[i] == DT_RESOURCE) &&
      (composite_devices_.find(input_devices_[i]->name()) ==
       composite_devices_.end())) {
    return host_cpu_device_;
  } else {
    return input_devices_[i];
  }
}

}  // namespace tensorflow