/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"

#include <memory>

#include "absl/strings/match.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {

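// Returns the local tensor for argument `index`. EagerKernelArgs holds no
// packed arguments, so any non-negative sub_index is rejected, and a missing
// tensor yields NotFound.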
Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index,
                                    Tensor* val) const {
  if (index.sub_index >= 0) {
    return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index,
                                   " for argument ", index.index);
  }
  Tensor* arg = tensor_args_.at(index.index).tensor;
  if (arg) {
    *val = *arg;
    return OkStatus();
  } else {
    return errors::NotFound("Argument ", index.index, " has no local tensor.");
  }
}

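// Copies every argument tensor into a vector. Note that this dereferences
// each TensorValue unconditionally, so callers must ensure all arguments are
// local.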
std::vector<Tensor> EagerKernelArgs::GetLocalTensors() const {
  std::vector<Tensor> local_inputs;
  local_inputs.reserve(tensor_args_.size());
  for (const TensorValue& tensor_value : tensor_args_) {
    local_inputs.push_back(*tensor_value.tensor);
  }
  return local_inputs;
}

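// Returns the runner injected at construction time, or, if none was set, a
// process-wide default runner that simply invokes the closure inline on the
// calling thread.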
std::function<void(std::function<void()>)>* KernelAndDevice::get_runner()
    const {
  if (runner_) {
    return runner_;
  } else {
    static auto* default_runner =
        new std::function<void(std::function<void()>)>(
            [](const std::function<void()>& f) { f(); });
    return default_runner;
  }
}

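// Releases the instantiated multi-device function handle, if any. A failure
// to release is logged and otherwise ignored, since a destructor cannot
// propagate the error.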
KernelAndDeviceFunc::~KernelAndDeviceFunc() {
  if (handle_ != kInvalidHandle) {
    Status status = pflr_->ReleaseHandle(handle_);
    if (!status.ok()) {
      LOG(INFO) << "Ignoring error status when releasing multi-device function "
                   "handle "
                << status.ToString();
    }
  }
}

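// Builds the OpKernel for `ndef` through the FunctionLibraryRuntime and
// precomputes per-input/per-output allocator attributes so Run() does not
// have to. Inputs that the kernel expects in host memory are remapped to the
// host CPU device here.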
Status KernelAndDeviceOp::Init(const bool log_device_placement,
                               const NodeDef& ndef,
                               GraphCollector* graph_collector) {
  OpKernel* k = nullptr;
  if (flr_ == nullptr) {
    return errors::Internal(
        "A valid FunctionLibraryRuntime must be provided when running ops "
        "based on OpKernel.");
  }
  std::shared_ptr<const NodeProperties> props;
  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
      ndef, flr_->GetFunctionLibraryDefinition(), &props));
  TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k));
  kernel_.reset(k);
  const auto* op_reg_data = OpRegistry::Global()->LookUp(ndef.op());
  if (op_reg_data != nullptr) {
    is_distributed_communication_op_ =
        op_reg_data->op_def.is_distributed_communication();
  }

  input_alloc_attrs_.resize(kernel_->num_inputs());
  input_devices_.resize(kernel_->num_inputs(), device_);
  for (size_t i = 0; i < input_alloc_attrs_.size(); ++i) {
    bool host = kernel_->input_memory_types()[i] == tensorflow::HOST_MEMORY;
    input_alloc_attrs_[i].set_on_host(host);
    if (host && input_devices_[i]->device_type() != DEVICE_CPU) {
      input_devices_[i] = host_cpu_device_;
    }
  }
  output_alloc_attrs_.resize(kernel_->num_outputs());
  for (size_t i = 0; i < output_alloc_attrs_.size(); ++i) {
    output_alloc_attrs_[i].set_on_host(kernel_->output_memory_types()[i] ==
                                       tensorflow::HOST_MEMORY);
  }

  return OkStatus();
}

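// Instantiates `ndef` as a multi-device function: resolves the function's
// input/output dtypes, populates the InstantiateOptions (including, on
// non-mobile builds, a grappler-based graph optimization callback), and
// forces function inlining so the whole call tree becomes one optimizable
// graph.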
Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement,
                                            const NodeDef& ndef,
                                            GraphCollector* graph_collector) {
  const OpDef* op_def = nullptr;
  const FunctionDef* function_def;
  if (flr_ == nullptr) {
    // If the function is being executed without an explicit device request,
    // look up the FunctionDef in the CPU's FLR. All FLRs share the same
    // library.
    function_def = pflr_->GetFLR(host_cpu_device_->name())
                       ->GetFunctionLibraryDefinition()
                       ->Find(ndef.op());
  } else {
    function_def = flr_->GetFunctionLibraryDefinition()->Find(ndef.op());
  }

  if (function_def != nullptr) {
    op_def = &(function_def->signature());
  } else {
    TF_RETURN_IF_ERROR(OpDefForOp(ndef.op(), &op_def));
  }
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(ndef, *op_def, &input_dtypes_, &output_dtypes_));

  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = device_ == nullptr ? "" : device_->name();
  options.is_multi_device_function = true;
  for (const Device* device : input_devices_) {
    options.input_devices.push_back(device->name());
  }
  options.composite_devices = composite_devices_;
  options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
  if (outputs_on_op_device_) {
    const FunctionLibraryDefinition* lib_def =
        pflr_->GetFunctionLibraryDefinition();
    const FunctionDef* fdef = lib_def->Find(ndef.op());
    if (fdef == nullptr) {
      return errors::InvalidArgument("Failed to find function ", ndef.op());
    }
    for (int i = 0; i < fdef->signature().output_arg_size(); ++i) {
      options.output_devices.push_back(options.target);
    }
  }

  const auto& it = ndef.attr().find("executor_type");
  if (it != ndef.attr().end()) {
    options.executor_type = it->second.s();
  }
  const auto& is_component_fn_it = ndef.attr().find("is_component_function");
  if (is_component_fn_it != ndef.attr().end()) {
    options.is_component_function = is_component_fn_it->second.b();
  }
#if !defined(IS_MOBILE_PLATFORM)
  // The Android TF library does not include grappler.
  const auto& config_it = ndef.attr().find("config_proto");
  if (config_it != ndef.attr().end()) {
    if (!options.config_proto.ParseFromString(config_it->second.s())) {
      return errors::InvalidArgument(
          "Failed to parse config_proto attribute as tensorflow::ConfigProto "
          "proto.");
    }
    grappler::GrapplerItem::OptimizationOptions optimization_options =
        grappler::CreateOptOptionsForEager();

    options.optimize_graph_fn = std::bind(
        grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
        std::placeholders::_3, std::placeholders::_4, std::placeholders::_5,
        options.config_proto, function_def->signature().name(),
        optimization_options, std::placeholders::_6);
  }
#endif  // !IS_MOBILE_PLATFORM
  options.graph_collector = graph_collector;

  options.allow_small_function_optimizations =
      allow_small_function_optimizations_;

  options.allow_control_flow_sync_execution =
      allow_control_flow_sync_execution_;

  options.shape_inference_on_tfe_dialect_import =
      shape_inference_on_tfe_dialect_import_;

  // In eager mode we always inline all functions into the top-level function
  // body graph, to get a single executable graph that can be optimized across
  // function boundaries (e.g. to prune unused inputs and outputs in a
  // function call chain). This is required to mimic graph mode execution,
  // with its aggressive pruning of nodes not in the transitive fanin of
  // fetches.
  options.config_proto.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);

  options.config_proto.set_log_device_placement(log_device_placement);

  options.int_args_and_retvals_on_device = int_args_and_retvals_on_device_;

  if (xla_compile_device_type_.has_value()) {
    options.xla_compile_device_type = xla_compile_device_type_.value();
  }

  TF_RETURN_IF_ERROR(
      pflr_->Instantiate(ndef.op(), AttrSlice(ndef), options, &handle_));
  return pflr_->IsCrossProcess(handle_, &is_cross_process_);
}

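// Init is a thin wrapper around InstantiateFunc that additionally caches the
// devices on which the function's outputs will be produced.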
Status KernelAndDeviceFunc::Init(const bool log_device_placement,
                                 const NodeDef& ndef,
                                 GraphCollector* graph_collector) {
  TF_RETURN_IF_ERROR(
      InstantiateFunc(log_device_placement, ndef, graph_collector));
  return pflr_->GetOutputDevices(handle_, &output_devices_);
}

namespace {
// In certain contexts (e.g. TPU async executions), the CancellationManager is
// used to shut down the device in error scenarios (as opposed to using the
// AsyncCompute's DoneCallback). This is handled through the
// {inc,dec}_num_deferred_ops_function.
struct OpExecutionState : public core::RefCounted {
  // TODO(nareshmodi): consider refcounting the cancellation_manager.
  CancellationManager cancellation_manager;
};
}  // anonymous namespace

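// Runs the kernel synchronously on `device_`. Cancellation behavior: an
// explicit manager from the caller wins; deferred kernels get a ref-counted
// manager that can outlive this call; all other kernels use a stack-local
// default. A rough usage sketch (names are illustrative, not a stable API):
//
//   EagerKernelArgs args(std::move(tensor_values));
//   std::vector<EagerKernelRet> rets;
//   TF_RETURN_IF_ERROR(kernel_and_device->Run(
//       /*step_container=*/nullptr, args, &rets,
//       /*cancellation_manager=*/nullptr,
//       /*eager_func_params=*/absl::nullopt, /*stack_trace=*/absl::nullopt,
//       /*coordination_service_agent=*/nullptr));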
Status KernelAndDeviceOp::Run(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace,
    CoordinationServiceAgent* coordination_service_agent) {
  OpKernelContext::Params params;
  params.device = device_;
  params.frame_iter = FrameAndIter(0, 0);
  params.inputs = *inputs.GetTensorValues();
  params.op_kernel = kernel_.get();
  params.resource_manager = device_->resource_manager();
  params.input_alloc_attrs = input_alloc_attrs_;
  params.output_attr_array = output_alloc_attrs_.data();
  params.function_library = flr_;
  params.slice_reader_cache = &slice_reader_cache_;
  params.rendezvous = rendezvous_;
  params.stack_trace = stack_trace;
  OpExecutionState* op_execution_state = nullptr;

  CancellationManager default_cancellation_manager;
  if (cancellation_manager) {
    params.cancellation_manager = cancellation_manager;
  } else if (kernel_->is_deferred()) {
    op_execution_state = new OpExecutionState;
    params.cancellation_manager = &op_execution_state->cancellation_manager;
    params.inc_num_deferred_ops_function = [op_execution_state]() {
      op_execution_state->Ref();
    };
    params.dec_num_deferred_ops_function = [op_execution_state]() {
      op_execution_state->Unref();
    };
  } else {
    params.cancellation_manager = &default_cancellation_manager;
  }

  params.log_memory = log_memory_;

  params.runner = get_runner();

  params.step_container = step_container;

  params.collective_executor =
      collective_executor_ ? collective_executor_->get() : nullptr;

  params.coordination_service_agent = coordination_service_agent;

  OpKernelContext context(&params);

  {
    port::ScopedFlushDenormal flush;
    port::ScopedSetRound round(FE_TONEAREST);
    // 'AnnotatedTraceMe' will trace both the scheduling time on the host and
    // the execution time on the device of the OpKernel.
    profiler::AnnotatedTraceMe activity(
        [&] { return kernel_->TraceString(context, /*verbose=*/false); },
        profiler::TraceMeLevel::kInfo);
    device_->Compute(kernel_.get(), &context);
  }

  // Release our reference to op_execution_state; any deferred ops still
  // running hold their own references.
  if (op_execution_state != nullptr) {
    op_execution_state->Unref();
  }

  Status s = context.status();
  if (TF_PREDICT_FALSE(!s.ok())) {
    if (errors::IsUnavailable(s) && !is_distributed_communication_op_) {
      s = errors::ReplaceErrorFromNonCommunicationOps(s, kernel_->name());
    }
    return s;
  }

  if (outputs != nullptr) {
    outputs->clear();
    for (int i = 0; i < context.num_outputs(); ++i) {
      const auto* output_tensor = context.mutable_output(i);
      if (output_tensor != nullptr) {
        outputs->push_back(Tensor(*output_tensor));
      } else {
        outputs->push_back(Tensor());
      }
    }
  }
  return OkStatus();
}

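// Assembles the FunctionLibraryRuntime::Options shared by the sync and async
// paths: step/op id reuse for cross-process functions, a per-step rendezvous
// (created here rather than taken from the eager context to avoid send/recv
// name collisions), and a cancellation manager that the options own whenever
// the caller does not supply one.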
std::shared_ptr<FunctionLibraryRuntime::Options>
KernelAndDeviceFunc::PrepareForRun(
    ScopedStepContainer* step_container, std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace,
    CoordinationServiceAgent* coordination_service_agent) {
  std::shared_ptr<FunctionLibraryRuntime::Options> opts = nullptr;
  if (eager_func_params.has_value()) {
    const EagerFunctionParams& params = eager_func_params.value();
    if (params.step_id.has_value()) {
      // If the function is a remote component of a cross-process function,
      // reuse its parent function's step id.
      opts = std::make_shared<FunctionLibraryRuntime::Options>(
          params.step_id.value());
    } else {
      opts = std::make_shared<FunctionLibraryRuntime::Options>();
    }
    // Reuse the op id if it exists.
    if (params.op_id != kInvalidOpId) {
      opts->op_id = params.op_id;
    }
  } else {
    opts = std::make_shared<FunctionLibraryRuntime::Options>();
    if (get_op_id_ && is_cross_process_) {
      // If the function is a cross-process function and the remote execution
      // goes through the eager service, create an eager op id for the
      // function.
      opts->op_id = get_op_id_();
    }
  }

  // We don't pass the rendezvous from the eager context because we can get
  // tensor name collisions in send/recv ops when running multiple instances
  // of the same multi-device function concurrently.
  Rendezvous* rendezvous = rendezvous_creator_(opts->step_id);
  opts->rendezvous = rendezvous;
  opts->create_rendezvous = false;

  // Create a cancellation manager for the FLR options if the caller does not
  // pass one in; if the caller does provide one, pass it through to the
  // process FLR.
  if (cancellation_manager) {
    opts->cancellation_manager = cancellation_manager;
  } else {
    opts->cancellation_manager = new CancellationManager;
  }
  opts->allow_dead_tensors = true;
  opts->step_container = step_container;
  opts->collective_executor =
      collective_executor_ ? collective_executor_->get() : nullptr;
  opts->stack_trace = stack_trace;

  opts->stats_collector = nullptr;
  opts->runner = get_runner();
  opts->coordination_service_agent = coordination_service_agent;

  outputs->clear();
  return opts;
}

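// Synchronous execution path. Packed or remote inputs, and calls carrying
// EagerFunctionParams, are not handled synchronously; they fall back to
// RunAsync plus a blocking Notification.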
Status KernelAndDeviceFunc::Run(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace,
    CoordinationServiceAgent* coordination_service_agent) {
  profiler::TraceMe activity("KernelAndDeviceFunc::Run",
                             profiler::TraceMeLevel::kInfo);
  // Don't try to handle packed or remote inputs synchronously.
  if (inputs.HasRemoteOrPackedInputs() || eager_func_params.has_value()) {
    Notification n;
    Status status;
    RunAsync(step_container, inputs, outputs, cancellation_manager,
             eager_func_params, coordination_service_agent,
             [&status, &n](Status s) {
               status = s;
               n.Notify();
             });
    n.WaitForNotification();
    return status;
  }
  std::shared_ptr<FunctionLibraryRuntime::Options> opts =
      PrepareForRun(step_container, outputs, cancellation_manager,
                    eager_func_params, stack_trace, coordination_service_agent);

  std::vector<Tensor> rets;
  Status s;
  {
    port::ScopedFlushDenormal flush;
    port::ScopedSetRound round(FE_TONEAREST);
    s.Update(pflr_->RunSync(*opts, handle_, inputs.GetLocalTensors(), &rets));
  }

  if (cancellation_manager == nullptr) {
    delete opts->cancellation_manager;
  }
  static_cast<Rendezvous*>(opts->rendezvous)->Unref();
  outputs->reserve(rets.size());
  for (auto& v : rets) {
    outputs->push_back(std::move(v));
  }
  return s;
}

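// Asynchronous execution path. The shared_ptr `opts` is captured by the done
// callback so the rendezvous and (if locally created) the cancellation
// manager stay alive until the function finishes.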
void KernelAndDeviceFunc::RunAsync(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    CoordinationServiceAgent* coordination_service_agent,
    std::function<void(const Status&)> done) {
  profiler::TraceMe activity("KernelAndDeviceFunc::RunAsync",
                             profiler::TraceMeLevel::kInfo);
  std::shared_ptr<FunctionLibraryRuntime::Options> opts = PrepareForRun(
      step_container, outputs, cancellation_manager, eager_func_params,
      absl::nullopt, coordination_service_agent);

  pflr_->Run(
      *opts, handle_, inputs, outputs,
      [opts, cancellation_manager, done = std::move(done)](const Status& s) {
        if (cancellation_manager == nullptr) {
          delete opts->cancellation_manager;
        }
        static_cast<Rendezvous*>(opts->rendezvous)->Unref();
        done(s);
      });
}

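// The accessors below report per-input/per-output device placement. For
// outputs, a nullptr return appears to mean "lives on the host": host-memory
// outputs for ops, and DT_RESOURCE handles for functions, whose backing
// resource's device is reported by OutputResourceDevice instead.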
tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const {
  if (kernel_->output_memory_types()[idx] == HOST_MEMORY) {
    return nullptr;
  }
  return device_;
}

tensorflow::Device* KernelAndDeviceFunc::OutputDevice(int idx) const {
  if (output_dtypes_[idx] == DT_RESOURCE) {
    return nullptr;
  }
  return output_devices_[idx];
}

tensorflow::Device* KernelAndDeviceOp::OutputResourceDevice(int idx) const {
  if (kernel_->output_type(idx) == DT_RESOURCE) {
    return device_;
  }
  return nullptr;
}

tensorflow::Device* KernelAndDeviceFunc::OutputResourceDevice(int idx) const {
  if (output_dtypes_[idx] == DT_RESOURCE) {
    return output_devices_[idx];
  }
  return nullptr;
}

Device* KernelAndDeviceOp::InputDevice(int i) const {
  return input_devices_[i];
}

Device* KernelAndDeviceFunc::InputDevice(int i) const {
  if ((input_dtypes_[i] == DT_RESOURCE) &&
      (composite_devices_.find(input_devices_[i]->name()) ==
       composite_devices_.end())) {
    return host_cpu_device_;
  } else {
    return input_devices_[i];
  }
}

}  // namespace tensorflow