/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/captured_function.h"

#include <utility>

#include "absl/time/clock.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {
namespace data {
namespace {

const char kDataServiceDataset[] = "DataServiceDataset";

// Simplistic implementation of the `StepStatsCollectorInterface` that only
// cares about collecting the CPU time needed to execute a captured function.
class SimpleStepStatsCollector : public StepStatsCollectorInterface {
 public:
  void IncrementProcessingTime(int64 delta) {
    mutex_lock l(mu_);
    processing_time_ += delta;
  }

  NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override {
    return new SimpleNodeExecStats(this);
  }

  string ReportAllocsOnResourceExhausted(const string& err) override {
    return "";
  }

  int64 processing_time() {
    tf_shared_lock l(mu_);
    return processing_time_;
  }

 private:
  class SimpleNodeExecStats : public NodeExecStatsInterface {
   public:
    explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
        : step_stats_collector_(step_stats_collector) {}

    void Done(const string& device) override {
      step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
                                                     start_time_ns_);
      delete this;
    }

    void RecordExecutorStarted() override {
      start_time_ns_ = absl::GetCurrentTimeNanos();
    }

    void RecordComputeStarted() override {}

    void RecordComputeEnded() override {}

    void RecordExecutorEnded() override {
      end_time_ns_ = absl::GetCurrentTimeNanos();
    }

    bool TrackAllocations() const override { return false; }

    void SetMemory(OpKernelContext* ctx) override {}

    void SetOutput(int slot, const Tensor* tensor) override {}

    void SetScheduled(int64 nanos) override {}

   private:
    int64 start_time_ns_ = 0;
    int64 end_time_ns_ = 0;
    SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
  };

  mutex mu_;
  int64 processing_time_ TF_GUARDED_BY(mu_) = 0;
};

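// Returns in `*out` a borrowed pointer to the `index`-th captured input of
// `func`, or an `OutOfRange` error if `index` is past the end of the
// captured inputs.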
Status GetCapturedInput(const CapturedFunction* const func, int index,
                        const Tensor** out) {
  if (TF_PREDICT_FALSE(index >= func->captured_inputs().size())) {
    return errors::OutOfRange(
        "Out of range access to captured inputs for function ",
        func->func().name(), ". Index: ", index,
        ". Num captured inputs: ", func->captured_inputs().size());
  }
  *out = &func->captured_inputs()[index];
  return Status::OK();
}

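// Produces the results of a short-circuitable function without executing it:
// each return value is copied directly from the argument or captured input
// identified by `info.indices`.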
Status RunShortCircuit(const ShortCircuitInfo& info,
                       const std::vector<Tensor>& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      rets->push_back(args[info.indices[i]]);
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return Status::OK();
}

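// Overload for rvalue arguments: moves an argument into the result whenever
// `info.can_move` indicates that this use of the argument is its last.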
Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      if (info.can_move[i]) {
        rets->push_back(std::move(args[info.indices[i]]));
      } else {
        rets->push_back(args[info.indices[i]]);
      }
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return Status::OK();
}

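// Determines whether the given function can be "short circuited", i.e.
// whether each of its return values is either an argument or a captured
// input, possibly forwarded through `Identity` ops. If so, `info->indices`
// records the input index that each return value forwards; for example, a
// function that returns its second and then its first argument yields
// indices {1, 0}. Stateful functions are never short-circuited.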
Status CreateShortCircuitInfo(OpKernelConstruction* ctx,
                              const NameAttrList& func,
                              ShortCircuitInfo* info) {
  auto& indices = info->indices;

  FunctionLibraryRuntime::Handle fn_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), &fn_handle));
  auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
    Status s = ctx->function_library()->ReleaseHandle(fn_handle);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to release handle: " << s.error_message();
    }
  });

  // If the function contains any stateful operations, we conservatively
  // execute the entire function.
  if (ctx->function_library()->IsStateful(func.name())) {
    return Status::OK();
  }

  const FunctionBody* fn_body =
      ctx->function_library()->GetFunctionBody(fn_handle);
  indices.resize(fn_body->ret_nodes.size());

  for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
    Node* ret_node = fn_body->ret_nodes[i];
    Node* ret_input_node;
    TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));

    while (ret_input_node->def().op() == "Identity") {
      TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
    }

    if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(ret_input_node->def(), "index", &(indices[i])));
    } else {
      indices.clear();
      break;
    }
  }

  // Compute the `can_move` vector.
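  // An argument may be moved into a return value only at its last use;
  // otherwise a later return value would read from a moved-from tensor.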
  if (!indices.empty()) {
    auto& can_move = info->can_move;
    std::map<int, int> last_use;
    for (size_t i = 0; i < indices.size(); ++i) {
      last_use[indices[i]] = i;
    }
    can_move.resize(indices.size());
    for (int i = 0, end = indices.size(); i < end; ++i) {
      can_move[i] = last_use[indices[i]] == i;
    }
  }

  return Status::OK();
}

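// Creates a function library definition containing `func_name` and all of
// the function definitions it transitively depends on.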
Status CreateFunctionLibraryDefinition(
    const FunctionLibraryDefinition* lib_def, const string& func_name,
    std::unique_ptr<FunctionLibraryDefinition>* result) {
  DCHECK(lib_def != nullptr);
  const FunctionDef* fdef = lib_def->Find(func_name);
  if (TF_PREDICT_FALSE(fdef == nullptr)) {
    return errors::FailedPrecondition(strings::StrCat(
        "Could not find required function definition ", func_name));
  }
  *result = absl::make_unique<FunctionLibraryDefinition>(
      lib_def->ReachableDefinitions(*fdef));
  return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
}

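// Returns OK if `function_def` contains only stateless operations; otherwise
// returns a `FailedPrecondition` error identifying a stateful operation.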
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return Status::OK();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return Status::OK();
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}

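// Looks up `name` in `lib_def`, returning an `InvalidArgument` error if the
// function is not found.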
Status LookupFunction(const FunctionLibraryDefinition& lib_def,
                      const string& name, const FunctionDef** fdef) {
  *fdef = lib_def.Find(name);
  if (*fdef == nullptr) {
    return errors::InvalidArgument(
        "Failed to find function ", name,
        " in function library: ", lib_def.ToProto().DebugString());
  }
  return Status::OK();
}

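// Base class for the call frames used to pass arguments into (and collect
// return values from) a single invocation of a captured function.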
class CallFrameBase : public CallFrameInterface {
 public:
  explicit CallFrameBase(DataTypeSlice ret_types)
      : ret_types_(ret_types), retvals_(ret_types.size()) {}

  // Caller methods.
  Status ConsumeRetvals(std::vector<Tensor>* retvals) {
    retvals->reserve(retvals_.size());
    int i = 0;
    for (auto&& val : retvals_) {
      if (!val) {
        return errors::Internal("No return value for index ", i, ".");
      }
      retvals->emplace_back(std::move(val.value()));
      ++i;
    }
    return Status::OK();
  }

  size_t num_retvals() const override { return retvals_.size(); }

  // Callee methods.
  Status SetRetval(int index, const Tensor& val) override {
    const int retvals_size = retvals_.size();
    if (index < retvals_size && val.dtype() == ret_types_[index] &&
        !retvals_[index]) {
      retvals_[index] = val;
      return Status::OK();
    } else if (index >= retvals_size) {
      return errors::InvalidArgument("Return value ", index,
                                     " is out of range.");
    } else if (val.dtype() != ret_types_[index]) {
      return errors::InvalidArgument("Expected type ",
                                     DataTypeString(ret_types_[index]),
                                     " for return value ", index, " but got ",
                                     DataTypeString(val.dtype()), ".");
    } else {
      return errors::Internal("Attempted to set return value ", index,
                              " more than once.");
    }
  }

 private:
  DataTypeSlice ret_types_;
  std::vector<gtl::optional<Tensor>> retvals_;
  TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
};

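// A call frame that owns its argument tensors, which allows the callee to
// consume (move out of) them.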
class OwnedArgsCallFrame : public CallFrameBase {
 public:
  OwnedArgsCallFrame(std::vector<Tensor>&& args,
                     const std::vector<Tensor>* captured_inputs,
                     DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(std::move(args)),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return Status::OK();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_.size()];
      return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

  // Since we own the argument tensors in `args_`, we can implement
  // `ConsumeArg()` for those arguments.
  void ConsumeArg(int index, Tensor* val) override {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, args_.size());
    *val = std::move(args_[index]);
  }
  bool CanConsumeArg(int index) const override {
    return index >= 0 && index < static_cast<int>(args_.size());
  }

 private:
  std::vector<Tensor> args_;
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

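// A call frame that borrows its argument tensors from the caller; the callee
// may read but not consume them.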
class BorrowedArgsCallFrame : public CallFrameBase {
 public:
  BorrowedArgsCallFrame(const std::vector<Tensor>& args,
                        const std::vector<Tensor>* captured_inputs,
                        DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(args),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return Status::OK();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_size];
      return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

 private:
  const std::vector<Tensor>& args_;                   // Not owned.
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

}  // namespace

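// Returns OK if `node` is stateless (or allowlisted as effectively
// stateless); otherwise returns a `FailedPrecondition` error naming the
// stateful op. `If` and `While` nodes are checked recursively through their
// branch, condition, and body functions.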
Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return Status::OK();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return Status::OK();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return Status::OK();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}

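// Creates an iterator over the dataset produced by applying
// `inst_captured_func` to `input_element`, without collecting statistics.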
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64 thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator) {
  return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index,
                                      inst_captured_func, prefix, out_iterator,
                                      /*node=*/nullptr);
}

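// As above, but additionally records function execution statistics in `node`
// when `node` is non-null.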
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64 thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node) {
  std::vector<Tensor> return_values;

  TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
      ctx, input_element, &return_values, node));

  if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    return errors::InvalidArgument(
        "Function must return a single scalar of dtype DT_VARIANT.");
  }

  // Retrieve the dataset that was created in `f`.
  DatasetBase* returned_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(return_values[0], &returned_dataset));

  // Create an iterator for the dataset that was returned by `f`.
  std::string iterator_prefix = strings::StrCat(prefix, "[", thread_index, "]");
  if (ctx->split_provider() == nullptr) {
    return returned_dataset->MakeIterator(ctx, parent, iterator_prefix,
                                          out_iterator);
  }
  // Strip out the split provider so that it doesn't apply to sub-iterators.
  IteratorContext::Params params(ctx);
  params.split_provider = nullptr;
  return returned_dataset->MakeIterator(IteratorContext(std::move(params)),
                                        parent, iterator_prefix, out_iterator);
}

/* static */
Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, const string& func_name, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  NameAttrList func;
  TF_RETURN_IF_ERROR(ctx->GetAttr(func_name, &func));
  return Create(ctx, std::move(func), params, out_metadata);
}

Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, NameAttrList&& func, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  out_metadata->reset(new FunctionMetadata(std::move(func), params));
  TF_RETURN_IF_ERROR(CreateFunctionLibraryDefinition(
      ctx->function_library()->GetFunctionLibraryDefinition(),
      (*out_metadata)->func_.name(), &(*out_metadata)->lib_def_));
  TF_RETURN_IF_ERROR(CreateShortCircuitInfo(
      ctx, (*out_metadata)->func_, &(*out_metadata)->short_circuit_info_));
  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(LookupFunction(*(*out_metadata)->lib_def(),
                                    (*out_metadata)->func().name(), &fdef));

  auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
  if (attr != fdef->attr().end() && attr->second.b()) {
    VLOG(1) << "Disabling multi-device execution for a function that uses the "
            << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute.";
    (*out_metadata)->use_multi_device_function_ = false;
    return Status::OK();
  }
  auto validate_arg = [](const OpDef::ArgDef& arg) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      VLOG(1) << "Disabling multi-device execution for a function with "
              << "a vector argument " << arg.name() << ".";
      return false;
    }
    return true;
  };
  for (const auto& arg : fdef->signature().input_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return Status::OK();
    }
  }
  for (const auto& arg : fdef->signature().output_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return Status::OK();
    }
  }
  return Status::OK();
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    const string& argument_name,
    std::unique_ptr<CapturedFunction>* out_function) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
  std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end());
  return Create(ctx, std::move(metadata), std::move(captured_inputs),
                out_function);
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor>&& captured_inputs,
    std::unique_ptr<CapturedFunction>* out_function) {
  *out_function = absl::WrapUnique(
      new CapturedFunction(std::move(metadata), std::move(captured_inputs)));
  return Status::OK();
}

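// Serializes the captured inputs into `other_arguments` (datasets via
// `AddInputDataset()`, all other tensors via `AddTensor()`) and adds the
// captured function and its dependencies to the graph being built.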
Status CapturedFunction::AddToGraph(
    SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b,
    std::vector<Node*>* other_arguments,
    DataTypeVector* other_arguments_types) const {
  other_arguments->reserve(captured_inputs_.size());
  other_arguments_types->reserve(captured_inputs_.size());
  for (const Tensor& t : captured_inputs_) {
    Node* node;
    DatasetBase* input;
    Status s = GetDatasetFromVariantTensor(t, &input);
    if (s.ok()) {
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
    } else {
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
    }
    other_arguments->emplace_back(node);
    other_arguments_types->emplace_back(t.dtype());
  }
  TF_RETURN_IF_ERROR(
      b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def()));
  return Status::OK();
}

Status CapturedFunction::Instantiate(
    IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
                              instantiated_captured_function) {
  // The context's runtime will be used for all subsequent calls.
  FunctionLibraryRuntime* lib = ctx->flr();
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.lib_def = metadata_->lib_def();
  inst_opts.create_kernels_eagerly = true;
  inst_opts.default_device_to_target = metadata_->use_default_device();
  inst_opts.config_proto =
      lib->config_proto() ? *lib->config_proto() : ConfigProto();
  if (!metadata_->use_inter_op_parallelism()) {
    inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
  }
  inst_opts.is_multi_device_function = metadata_->use_multi_device_function();

  // We infer the target device from the function library runtime.
  DCHECK(lib->device() != nullptr);
  inst_opts.target = lib->device()->name();

  // Maps from a CompositeDevice name to underlying physical device names.
  absl::flat_hash_map<string, std::vector<string>> composite_devices;

  if (inst_opts.is_multi_device_function) {
    // Compute devices of non-captured inputs.
    //
    // We infer the number of non-captured inputs by subtracting the number
    // of captured inputs from the number of input arguments and we infer the
    // input devices from the function library runtime.
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(
        LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
    size_t num_non_captured_inputs =
        fdef->signature().input_arg_size() - captured_inputs_.size();
    for (size_t i = 0; i < num_non_captured_inputs; ++i) {
      inst_opts.input_devices.push_back(inst_opts.target);
    }
    // Compute devices of captured inputs.
    // TODO(jsimsa): Correctly handle tensors on devices other than CPU:0.
    Device* cpu_device;
    TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
    std::unordered_map<int, DtypeAndPartialTensorShape>&
        input_resource_variable_dtypes_and_shapes =
            inst_opts.input_resource_dtypes_and_shapes;
    for (size_t i = 0; i < captured_inputs_.size(); ++i) {
      const auto& input = captured_inputs_[i];
      DataType dtype = input.dtype();
      if (dtype == DT_RESOURCE) {
        const auto& handles = input.flat<ResourceHandle>();
        const ResourceHandle& handle0 = handles(0);
        string composite_device;
        auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
        if (iter != fdef->arg_attr().end()) {
          auto arg_attr = iter->second.attr().find("_composite_device");
          if (arg_attr != iter->second.attr().end()) {
            composite_device = arg_attr->second.s();
          }
        }
        if (!composite_device.empty()) {
          if (composite_devices.find(composite_device) ==
              composite_devices.end()) {
            for (int i = 0; i < handles.size(); ++i) {
              composite_devices[composite_device].push_back(
                  handles(i).device());
            }
          }
          inst_opts.input_devices.push_back(composite_device);
        } else {
          inst_opts.input_devices.push_back(handle0.device());
        }
        const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
        // Set dtypes and shapes for resource variable inputs.
        if (!dtypes_and_shapes.empty()) {
          input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
                                                    i] =
              dtypes_and_shapes.at(0);
        }
      } else if (MTypeFromDType(dtype) == HOST_MEMORY) {
        inst_opts.input_devices.push_back(cpu_device->name());
      } else {
        // Fall back to using the function library runtime device.
        inst_opts.input_devices.push_back(inst_opts.target);
      }
    }

    for (const auto& it : composite_devices) {
      inst_opts.composite_devices[it.first] = &it.second;
    }

    for (int i = 0, end = fdef->signature().output_arg_size(); i < end; ++i) {
      inst_opts.output_devices.push_back(inst_opts.target);
    }

#if !defined(IS_MOBILE_PLATFORM)
    grappler::GrapplerItem::OptimizationOptions optimization_options;
    optimization_options.allow_pruning_stateful_and_dataset_ops = false;
    ConfigProto config_proto = inst_opts.config_proto;
    // Layout optimizations are excluded because they assume that ops without
    // explicit device assignment will be placed on GPU (if available) but
    // that's not the case for operations within tf.data functions.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_layout_optimizer(RewriterConfig::OFF);
    // TODO(b/120437209): Re-enable constant folding.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_constant_folding(RewriterConfig::OFF);
    inst_opts.optimize_graph_fn =
        std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3,
                  std::placeholders::_4, std::placeholders::_5,
                  std::move(config_proto), fdef->signature().name(),
                  std::move(optimization_options), std::placeholders::_6);
#endif  // !IS_MOBILE_PLATFORM
  }

  FunctionLibraryRuntime::Handle f_handle;
  TF_RETURN_IF_ERROR(ctx->function_handle_cache()->Instantiate(
      metadata_->func().name(), AttrSlice(&metadata_->func().attr()), inst_opts,
      &f_handle));

  DataTypeVector ret_types;
  TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));

  bool is_multi_device;
  TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
  return InstantiatedCapturedFunction::Create(
      lib, f_handle, std::move(ret_types), *ctx->runner(), this,
      is_multi_device, instantiated_captured_function);
}

Status CapturedFunction::CheckExternalState() const {
  for (const auto& name : lib_def()->ListFunctionNames()) {
    TF_RETURN_IF_ERROR(
        IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
  }
  return Status::OK();
}

CapturedFunction::CapturedFunction(
    std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor> captured_inputs)
    : metadata_(std::move(metadata)),
      captured_inputs_(std::move(captured_inputs)) {}

Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
                                       bool* is_multi_device) const {
  if (!metadata_->use_multi_device_function()) {
    *is_multi_device = false;
    return Status::OK();
  }

  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(
      LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));

  Device* current_device = ctx->flr()->device();
  DeviceType current_device_type(current_device->device_type());
  DeviceNameUtils::ParsedName current_device_name;
  if (!DeviceNameUtils::ParseFullName(current_device->name(),
                                      &current_device_name)) {
    return errors::InvalidArgument("Failed to parse device name: ",
                                   current_device->name());
  }

  // Check if any of the captured inputs are placed on a device not compatible
  // with the current device. For non-captured inputs, we assume they are
  // placed on the current device.
  for (const auto& input : captured_inputs_) {
    DataType dtype = input.dtype();
    if (dtype == DT_RESOURCE) {
      const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
      DeviceNameUtils::ParsedName resource_device_name;
      if (!DeviceNameUtils::ParseFullName(handle.device(),
                                          &resource_device_name)) {
        return errors::InvalidArgument("Failed to parse device name: ",
                                       handle.device());
      }
      if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                  resource_device_name)) {
        *is_multi_device = true;
        return Status::OK();
      }
    }
  }

  // Check if all ops could be placed on the current device.
  for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
    for (const auto& node : fdef->node_def()) {
      // Check if the op has a kernel available for the current device.
      if (!KernelDefAvailable(current_device_type, node)) {
        *is_multi_device = true;
        return Status::OK();
      }
      // If the op has a requested device, check if the requested device is
      // compatible with the current device.
      if (!node.device().empty()) {
        DeviceNameUtils::ParsedName node_device_name;
        if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
          return errors::InvalidArgument("Failed to parse device name: ",
                                         node.device());
        }
        if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                    node_device_name)) {
          *is_multi_device = true;
          return Status::OK();
        }
      }
    }
  }

  *is_multi_device = false;
  return Status::OK();
}

/* static */
Status InstantiatedCapturedFunction::Create(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device,
    std::unique_ptr<InstantiatedCapturedFunction>* out_function) {
  out_function->reset(new InstantiatedCapturedFunction(
      lib, f_handle, ret_types, runner, captured_func, is_multi_device));
  return Status::OK();
}

InstantiatedCapturedFunction::InstantiatedCapturedFunction(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device)
    : lib_(lib),
      f_handle_(f_handle),
      ret_types_(std::move(ret_types)),
      captured_runner_(std::move(runner)),
      captured_func_(captured_func),
      is_multi_device_(is_multi_device) {}

Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
                                         std::vector<Tensor>&& args,
                                         std::vector<Tensor>* rets) const {
  return Run(ctx, std::move(args), rets, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::Run(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, std::move(args), captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
                           ret_types_);
  profiler::TraceMe activity(
      [&] {
        return absl::StrCat(
            "InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    // TODO(jsimsa): Factor out common code for Run, RunAsync, and
    // RunWithBorrowedArgs.
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* ret) const {
  return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return absl::StrCat(
            "InstantiatedCapturedFunction::RunWithBorrowedArgs#id=",
            f_opts.step_id, "#");
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunInstantiated(
    const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = &captured_runner_;
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager;
  f_opts.cancellation_manager = &cancellation_manager;

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return absl::StrCat("InstantiatedCapturedFunction::RunInstantiated#id=",
                            f_opts.step_id, "#");
      },
      profiler::TraceMeLevel::kInfo);
  TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  return frame.ConsumeRetvals(rets);
}

void InstantiatedCapturedFunction::RunAsync(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    // Run the `done` callback on a threadpool thread, because it will
    // potentially do a non-trivial amount of (e.g. copying) work, and we may
    // want to run that concurrently with the next invocation.
    Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
    (*ctx->runner())(
        std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
                  std::move(done)));
    return;
  }

  // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
  // be deleted before `done` is called. Take care not to capture `ctx` in any
  // code that may execute asynchronously in this function.
  OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
      std::move(args), &captured_func_->captured_inputs(), ret_types_);

  FunctionLibraryRuntime::Options f_opts;
  ResourceMgr* resource_mgr = lib_->device()->resource_manager();
  ScopedStepContainer* step_container = new ScopedStepContainer(
      f_opts.step_id, [resource_mgr](const string& name) {
        resource_mgr->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  auto cancellation_manager =
      absl::make_unique<CancellationManager>(ctx->cancellation_manager());
  f_opts.cancellation_manager = cancellation_manager.get();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  // Transfer ownership of the cancellation manager to `callback`.
  CancellationManager* raw_cancellation_manager =
      cancellation_manager.release();
  auto callback = std::bind(
      [this, rets, step_container, raw_cancellation_manager, frame, node,
       collect_usage](
          const FunctionLibraryRuntime::DoneCallback& done,
          IteratorContext* ctx,
          const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
          // Begin unbound arguments.
          Status s) {
        delete step_container;
        delete raw_cancellation_manager;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets);
        }
        delete frame;
        if (node) {
          // TODO(b/129085499): Utilize the `node_name`, which is unique,
          // rather than the prefix for the function execution time statistics.
          // `prefix_with_func_name` would then be node_name + func_name.
          if (ctx->stats_aggregator()) {
            string prefix_with_func_name =
                strings::StrCat(node->name(), stats_utils::kDelimiter,
                                captured_func_->func().name());
            ctx->stats_aggregator()->AddToHistogram(
                stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
                {static_cast<float>(stats_collector->processing_time())},
                node->num_elements());
          }
          node->add_processing_time(stats_collector->processing_time());
        }
        if (collect_usage) {
          node->record_start(EnvTime::NowNanos());
        }
        done(s);
        if (collect_usage) {
          node->record_stop(EnvTime::NowNanos());
        }
      },
      std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);

  profiler::TraceMe activity(
      [&] {
        return absl::StrCat(
            "InstantiatedCapturedFunction::RunAsync#id=", f_opts.step_id, "#");
      },
      profiler::TraceMeLevel::kInfo);
  // Stop the usage collection before calling `Run()` because `callback` may
  // be executed synchronously, and so the `node->record_start()` call within
  // `callback` would violate nesting.
  if (collect_usage) node->record_stop(EnvTime::NowNanos());
  lib_->Run(f_opts, f_handle_, frame, std::move(callback));
  if (collect_usage) node->record_start(EnvTime::NowNanos());
}

bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
  // Rendezvous should only be created by the FLR for non-CPU single-device
  // functions. For multi-device functions the appropriate rendezvous will be
  // created by the process FLR.
  return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;
}

}  // namespace data
}  // namespace tensorflow