/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/captured_function.h"

#include <utility>

#include "absl/time/clock.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {
namespace data {
namespace {

const char kDataServiceDataset[] = "DataServiceDataset";

// Simplistic implementation of the `StepStatsCollectorInterface` that only
// cares about collecting the CPU time needed to execute a captured function.
class SimpleStepStatsCollector : public StepStatsCollectorInterface {
 public:
  void IncrementProcessingTime(int64 delta) {
    mutex_lock l(mu_);
    processing_time_ += delta;
  }

  NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override {
    return new SimpleNodeExecStats(this);
  }

  string ReportAllocsOnResourceExhausted(const string& err) override {
    return "";
  }

  int64 processing_time() {
    tf_shared_lock l(mu_);
    return processing_time_;
  }

 private:
  class SimpleNodeExecStats : public NodeExecStatsInterface {
   public:
    explicit SimpleNodeExecStats(
        SimpleStepStatsCollector* step_stats_collector)
        : step_stats_collector_(step_stats_collector) {}

    void Done(const string& device) override {
      step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
                                                     start_time_ns_);
      delete this;
    }

    void RecordExecutorStarted() override {
      start_time_ns_ = absl::GetCurrentTimeNanos();
    }

    void RecordComputeStarted() override {}

    void RecordComputeEnded() override {}

    void RecordExecutorEnded() override {
      end_time_ns_ = absl::GetCurrentTimeNanos();
    }

    bool TrackAllocations() const override { return false; }

    void SetMemory(OpKernelContext* ctx) override {}

    void SetOutput(int slot, const Tensor* tensor) override {}

    void SetScheduled(int64 nanos) override {}

   private:
    int64 start_time_ns_ = 0;
    int64 end_time_ns_ = 0;
    SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
  };

  mutex mu_;
  int64 processing_time_ TF_GUARDED_BY(mu_) = 0;
};

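// Looks up the `index`-th captured input of `func`, returning an OutOfRange
// error if the index exceeds the number of captured inputs.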
Status GetCapturedInput(const CapturedFunction* const func, int index,
                        const Tensor** out) {
  if (TF_PREDICT_FALSE(index >= func->captured_inputs().size())) {
    return errors::OutOfRange(
        "Out of range access to captured inputs for function ",
        func->func().name(), ". Index: ", index,
        ". Num captured inputs: ", func->captured_inputs().size());
  }
  *out = &func->captured_inputs()[index];
  return Status::OK();
}

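// Short-circuits the function call: forwards the arguments (and captured
// inputs) selected by `info.indices` directly to `rets` without running the
// function.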
Status RunShortCircuit(const ShortCircuitInfo& info,
                       const std::vector<Tensor>& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      rets->push_back(args[info.indices[i]]);
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return Status::OK();
}

Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      if (info.can_move[i]) {
        rets->push_back(std::move(args[info.indices[i]]));
      } else {
        rets->push_back(args[info.indices[i]]);
      }
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return Status::OK();
}

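// Inspects the function named by `func` and, if it is stateless and merely
// forwards its arguments (possibly through `Identity` nodes) to its return
// values, records the forwarded argument indices in `info`, along with which
// of those arguments may be moved rather than copied.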
Status CreateShortCircuitInfo(OpKernelConstruction* ctx,
                              const NameAttrList& func,
                              ShortCircuitInfo* info) {
  auto& indices = info->indices;

  FunctionLibraryRuntime::Handle fn_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), &fn_handle));
  auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
    Status s = ctx->function_library()->ReleaseHandle(fn_handle);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to release handle: " << s.error_message();
    }
  });

  // If the function contains any stateful operations, we conservatively
  // execute the entire function.
  if (ctx->function_library()->IsStateful(func.name())) {
    return Status::OK();
  }

  const FunctionBody* fn_body =
      ctx->function_library()->GetFunctionBody(fn_handle);
  indices.resize(fn_body->ret_nodes.size());

  for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
    Node* ret_node = fn_body->ret_nodes[i];
    Node* ret_input_node;
    TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));

    while (ret_input_node->def().op() == "Identity") {
      TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
    }

    if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(ret_input_node->def(), "index", &(indices[i])));
    } else {
      indices.clear();
      break;
    }
  }

  // Compute the `can_move` vector.
  if (!indices.empty()) {
    auto& can_move = info->can_move;
    std::map<int, int> last_use;
    for (size_t i = 0; i < indices.size(); ++i) {
      last_use[indices[i]] = i;
    }
    can_move.resize(indices.size());
    for (int i = 0, end = indices.size(); i < end; ++i) {
      can_move[i] = last_use[indices[i]] == i;
    }
  }

  return Status::OK();
}

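// Creates a function library definition containing `func_name` and the
// function definitions it transitively references, copied from `lib_def`.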
Status CreateFunctionLibraryDefinition(
    const FunctionLibraryDefinition* lib_def, const string& func_name,
    std::unique_ptr<FunctionLibraryDefinition>* result) {
  DCHECK(lib_def != nullptr);
  const FunctionDef* fdef = lib_def->Find(func_name);
  if (TF_PREDICT_FALSE(fdef == nullptr)) {
    return errors::FailedPrecondition(strings::StrCat(
        "Could not find required function definition ", func_name));
  }
  *result = absl::make_unique<FunctionLibraryDefinition>(
      lib_def->ReachableDefinitions(*fdef));
  return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
}

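// Returns OK if `function_def` is stateless, or if every stateful node it
// contains is considered stateless by `IsNodeStateful`; otherwise returns an
// error identifying the offending op.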
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return Status::OK();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return Status::OK();
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}

Status LookupFunction(const FunctionLibraryDefinition& lib_def,
                      const string& name, const FunctionDef** fdef) {
  *fdef = lib_def.Find(name);
  if (*fdef == nullptr) {
    return errors::InvalidArgument(
        "Failed to find function ", name,
        " in function library: ", lib_def.ToProto().DebugString());
  }
  return Status::OK();
}

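// Base call frame used to invoke a captured function: the callee sets its
// return values through `SetRetval()` and the caller retrieves them with
// `ConsumeRetvals()`.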
class CallFrameBase : public CallFrameInterface {
 public:
  explicit CallFrameBase(DataTypeSlice ret_types)
      : ret_types_(ret_types), retvals_(ret_types.size()) {}

  // Caller methods.
  Status ConsumeRetvals(std::vector<Tensor>* retvals) {
    retvals->reserve(retvals_.size());
    int i = 0;
    for (auto&& val : retvals_) {
      if (!val) {
        return errors::Internal("No return value for index ", i, ".");
      }
      retvals->emplace_back(std::move(val.value()));
      ++i;
    }
    return Status::OK();
  }

  size_t num_retvals() const override { return retvals_.size(); }

  // Callee methods.
  Status SetRetval(int index, const Tensor& val) override {
    const int retvals_size = retvals_.size();
    if (index < retvals_size && val.dtype() == ret_types_[index] &&
        !retvals_[index]) {
      retvals_[index] = val;
      return Status::OK();
    } else if (index >= retvals_size) {
      return errors::InvalidArgument("Return value ", index,
                                     " is out of range.");
    } else if (val.dtype() != ret_types_[index]) {
      return errors::InvalidArgument("Expected type ",
                                     DataTypeString(ret_types_[index]),
                                     " for return value ", index, " but got ",
                                     DataTypeString(val.dtype()), ".");
    } else {
      return errors::Internal("Attempted to set return value ", index,
                              " more than once.");
    }
  }

 private:
  DataTypeSlice ret_types_;
  std::vector<gtl::optional<Tensor>> retvals_;
  TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
};

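// Call frame that takes ownership of the argument tensors, which allows the
// callee to consume (move out of) arguments that are not used again.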
class OwnedArgsCallFrame : public CallFrameBase {
 public:
  OwnedArgsCallFrame(std::vector<Tensor>&& args,
                     const std::vector<Tensor>* captured_inputs,
                     DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(std::move(args)),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return Status::OK();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_.size()];
      return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

  // Since we own the argument tensors in `args_`, we can implement
  // `ConsumeArg()` for those arguments.
  void ConsumeArg(int index, Tensor* val) override {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, args_.size());
    *val = std::move(args_[index]);
  }
  bool CanConsumeArg(int index) const override {
    return index >= 0 && index < static_cast<int>(args_.size());
  }

 private:
  std::vector<Tensor> args_;
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

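// Call frame that borrows the argument tensors from the caller; arguments can
// therefore only be read, never consumed.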
class BorrowedArgsCallFrame : public CallFrameBase {
 public:
  BorrowedArgsCallFrame(const std::vector<Tensor>& args,
                        const std::vector<Tensor>* captured_inputs,
                        DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(args),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return Status::OK();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_size];
      return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

 private:
  const std::vector<Tensor>& args_;                   // Not owned.
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

}  // namespace

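// Returns OK if `node` is either stateless or allowlisted as effectively
// stateless (e.g. source dataset ops, `Assert`, and `If`/`While` ops whose
// branch and body functions are themselves stateless); otherwise returns
// FailedPrecondition.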
Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return Status::OK();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return Status::OK();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return Status::OK();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}

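// Runs `inst_captured_func` on `input_element`, expects it to return a single
// scalar DT_VARIANT tensor wrapping a dataset, and creates an iterator over
// that dataset.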
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64 thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator) {
  return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index,
                                      inst_captured_func, prefix, out_iterator,
                                      /*node=*/nullptr);
}

Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64 thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node) {
  std::vector<Tensor> return_values;

  TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
      ctx, input_element, &return_values, node));

  if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    return errors::InvalidArgument(
        "Function must return a single scalar of dtype DT_VARIANT.");
  }

  // Retrieve the dataset that was created in `f`.
  DatasetBase* returned_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(return_values[0], &returned_dataset));

  // Create an iterator for the dataset that was returned by `f`.
  std::string iterator_prefix = strings::StrCat(prefix, "[", thread_index, "]");
  if (ctx->split_provider() == nullptr) {
    return returned_dataset->MakeIterator(ctx, parent, iterator_prefix,
                                          out_iterator);
  }
  // Strip out the split provider so that it doesn't apply to sub-iterators.
  IteratorContext::Params params(ctx);
  params.split_provider = nullptr;
  return returned_dataset->MakeIterator(IteratorContext(std::move(params)),
                                        parent, iterator_prefix, out_iterator);
}

/* static */
Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, const string& func_name, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  NameAttrList func;
  TF_RETURN_IF_ERROR(ctx->GetAttr(func_name, &func));
  return Create(ctx, std::move(func), params, out_metadata);
}

Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, NameAttrList&& func, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  out_metadata->reset(new FunctionMetadata(std::move(func), params));
  TF_RETURN_IF_ERROR(CreateFunctionLibraryDefinition(
      ctx->function_library()->GetFunctionLibraryDefinition(),
      (*out_metadata)->func_.name(), &(*out_metadata)->lib_def_));
  TF_RETURN_IF_ERROR(CreateShortCircuitInfo(
      ctx, (*out_metadata)->func_, &(*out_metadata)->short_circuit_info_));
  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(LookupFunction(*(*out_metadata)->lib_def(),
                                    (*out_metadata)->func().name(), &fdef));

  auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
  if (attr != fdef->attr().end() && attr->second.b()) {
    VLOG(1) << "Disabling multi-device execution for a function that uses the "
            << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute.";
    (*out_metadata)->use_multi_device_function_ = false;
    return Status::OK();
  }
  auto validate_arg = [](const OpDef::ArgDef& arg) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      VLOG(1) << "Disabling multi-device execution for a function with "
              << "a vector argument " << arg.name() << ".";
      return false;
    }
    return true;
  };
  for (const auto& arg : fdef->signature().input_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return Status::OK();
    }
  }
  for (const auto& arg : fdef->signature().output_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return Status::OK();
    }
  }
  return Status::OK();
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    const string& argument_name,
    std::unique_ptr<CapturedFunction>* out_function) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
  std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end());
  return Create(ctx, std::move(metadata), std::move(captured_inputs),
                out_function);
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor>&& captured_inputs,
    std::unique_ptr<CapturedFunction>* out_function) {
  *out_function = absl::WrapUnique(
      new CapturedFunction(std::move(metadata), std::move(captured_inputs)));
  return Status::OK();
}

Status CapturedFunction::AddToGraph(
    SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b,
    std::vector<Node*>* other_arguments,
    DataTypeVector* other_arguments_types) const {
  other_arguments->reserve(captured_inputs_.size());
  other_arguments_types->reserve(captured_inputs_.size());
  for (const Tensor& t : captured_inputs_) {
    Node* node;
    DatasetBase* input;
    Status s = GetDatasetFromVariantTensor(t, &input);
    if (s.ok()) {
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
    } else {
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
    }
    other_arguments->emplace_back(node);
    other_arguments_types->emplace_back(t.dtype());
  }
  TF_RETURN_IF_ERROR(
      b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def()));
  return Status::OK();
}

Status CapturedFunction::Instantiate(
    IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
                              instantiated_captured_function) {
  // The context's runtime will be used for all subsequent calls.
  FunctionLibraryRuntime* lib = ctx->flr();
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.lib_def = metadata_->lib_def();
  inst_opts.create_kernels_eagerly = true;
  inst_opts.default_device_to_target = metadata_->use_default_device();
  inst_opts.config_proto =
      lib->config_proto() ? *lib->config_proto() : ConfigProto();
  if (!metadata_->use_inter_op_parallelism()) {
    inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
  }
  inst_opts.is_multi_device_function = metadata_->use_multi_device_function();

  // We infer the target device from the function library runtime.
  DCHECK(lib->device() != nullptr);
  inst_opts.target = lib->device()->name();

  // Maps from a CompositeDevice name to underlying physical device names.
  absl::flat_hash_map<string, std::vector<string>> composite_devices;

  if (inst_opts.is_multi_device_function) {
    // Compute devices of non-captured inputs.
    //
    // We infer the number of non-captured inputs by subtracting the number
    // of captured inputs from the number of input arguments and we infer the
    // input devices from the function library runtime.
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(
        LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
    size_t num_non_captured_inputs =
        fdef->signature().input_arg_size() - captured_inputs_.size();
    for (size_t i = 0; i < num_non_captured_inputs; ++i) {
      inst_opts.input_devices.push_back(inst_opts.target);
    }
    // Compute devices of captured inputs.
    // TODO(jsimsa): Correctly handle tensors on devices other than CPU:0.
    Device* cpu_device;
    TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
    std::unordered_map<int, DtypeAndPartialTensorShape>&
        input_resource_variable_dtypes_and_shapes =
            inst_opts.input_resource_dtypes_and_shapes;
    for (size_t i = 0; i < captured_inputs_.size(); ++i) {
      const auto& input = captured_inputs_[i];
      DataType dtype = input.dtype();
      if (dtype == DT_RESOURCE) {
        const auto& handles = input.flat<ResourceHandle>();
        const ResourceHandle& handle0 = handles(0);
        string composite_device;
        auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
        if (iter != fdef->arg_attr().end()) {
          auto arg_attr = iter->second.attr().find("_composite_device");
          if (arg_attr != iter->second.attr().end()) {
            composite_device = arg_attr->second.s();
          }
        }
        if (!composite_device.empty()) {
          if (composite_devices.find(composite_device) ==
              composite_devices.end()) {
            for (int i = 0; i < handles.size(); ++i) {
              composite_devices[composite_device].push_back(
                  handles(i).device());
            }
          }
          inst_opts.input_devices.push_back(composite_device);
        } else {
          inst_opts.input_devices.push_back(handle0.device());
        }
        const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
        // Set dtypes and shapes for resource variable inputs.
        if (!dtypes_and_shapes.empty()) {
          input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
                                                    i] =
              dtypes_and_shapes.at(0);
        }
      } else if (MTypeFromDType(dtype) == HOST_MEMORY) {
        inst_opts.input_devices.push_back(cpu_device->name());
      } else {
        // Fall back to using the function library runtime device.
        inst_opts.input_devices.push_back(inst_opts.target);
      }
    }

    for (const auto& it : composite_devices) {
      inst_opts.composite_devices[it.first] = &it.second;
    }

    for (int i = 0, end = fdef->signature().output_arg_size(); i < end; ++i) {
      inst_opts.output_devices.push_back(inst_opts.target);
    }

#if !defined(IS_MOBILE_PLATFORM)
    grappler::GrapplerItem::OptimizationOptions optimization_options;
    optimization_options.allow_pruning_stateful_and_dataset_ops = false;
    ConfigProto config_proto = inst_opts.config_proto;
    // Layout optimizations are excluded because they assume that ops without
    // explicit device assignment will be placed on GPU (if available) but
    // that's not the case for operations within tf.data functions.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_layout_optimizer(RewriterConfig::OFF);
    // TODO(b/120437209): Re-enable constant folding.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_constant_folding(RewriterConfig::OFF);
    inst_opts.optimize_graph_fn =
        std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3,
                  std::placeholders::_4, std::placeholders::_5,
                  std::move(config_proto), fdef->signature().name(),
                  std::move(optimization_options), std::placeholders::_6);
#endif  // !IS_MOBILE_PLATFORM
  }

  FunctionLibraryRuntime::Handle f_handle;
  TF_RETURN_IF_ERROR(ctx->function_handle_cache()->Instantiate(
      metadata_->func().name(), AttrSlice(&metadata_->func().attr()), inst_opts,
      &f_handle));

  DataTypeVector ret_types;
  TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));

  bool is_multi_device;
  TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
  return InstantiatedCapturedFunction::Create(
      lib, f_handle, std::move(ret_types), *ctx->runner(), this,
      is_multi_device, instantiated_captured_function);
}

Status CapturedFunction::CheckExternalState() const {
  for (const auto& name : lib_def()->ListFunctionNames()) {
    TF_RETURN_IF_ERROR(
        IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
  }
  return Status::OK();
}

CapturedFunction::CapturedFunction(
    std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor> captured_inputs)
    : metadata_(std::move(metadata)),
      captured_inputs_(std::move(captured_inputs)) {}

Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
                                       bool* is_multi_device) const {
  if (!metadata_->use_multi_device_function()) {
    *is_multi_device = false;
    return Status::OK();
  }

  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(
      LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));

  Device* current_device = ctx->flr()->device();
  DeviceType current_device_type(current_device->device_type());
  DeviceNameUtils::ParsedName current_device_name;
  if (!DeviceNameUtils::ParseFullName(current_device->name(),
                                      &current_device_name)) {
    return errors::InvalidArgument("Failed to parse device name: ",
                                   current_device->name());
  }

  // Check if any of the captured inputs are placed on a device not compatible
  // with the current device. For non-captured inputs, we assume they are
  // placed on the current device.
  for (const auto& input : captured_inputs_) {
    DataType dtype = input.dtype();
    if (dtype == DT_RESOURCE) {
      const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
      DeviceNameUtils::ParsedName resource_device_name;
      if (!DeviceNameUtils::ParseFullName(handle.device(),
                                          &resource_device_name)) {
        return errors::InvalidArgument("Failed to parse device name: ",
                                       handle.device());
      }
      if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                  resource_device_name)) {
        *is_multi_device = true;
        return Status::OK();
      }
    }
  }

  // Check if all ops could be placed on the current device.
  for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
    for (const auto& node : fdef->node_def()) {
      // Check if the op has a kernel available for the current device.
      if (!KernelDefAvailable(current_device_type, node)) {
        *is_multi_device = true;
        return Status::OK();
      }
      // If the op has a requested device, check if the requested device is
      // compatible with the current device.
      if (!node.device().empty()) {
        DeviceNameUtils::ParsedName node_device_name;
        if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
          return errors::InvalidArgument("Failed to parse device name: ",
                                         node.device());
        }
        if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                    node_device_name)) {
          *is_multi_device = true;
          return Status::OK();
        }
      }
    }
  }

  *is_multi_device = false;
  return Status::OK();
}

/* static */
Status InstantiatedCapturedFunction::Create(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device,
    std::unique_ptr<InstantiatedCapturedFunction>* out_function) {
  out_function->reset(new InstantiatedCapturedFunction(
      lib, f_handle, ret_types, runner, captured_func, is_multi_device));
  return Status::OK();
}

InstantiatedCapturedFunction::InstantiatedCapturedFunction(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device)
    : lib_(lib),
      f_handle_(f_handle),
      ret_types_(std::move(ret_types)),
      captured_runner_(std::move(runner)),
      captured_func_(captured_func),
      is_multi_device_(is_multi_device) {}

Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
                                         std::vector<Tensor>&& args,
                                         std::vector<Tensor>* rets) const {
  return Run(ctx, std::move(args), rets, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::Run(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, std::move(args), captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
                           ret_types_);
  profiler::TraceMe activity(
      [&] {
        return absl::StrCat(
            "InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    // TODO(jsimsa): Factor out common code for Run, RunAsync, and
    // RunWithBorrowedArguments
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* ret) const {
  return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return absl::StrCat(
            "InstantiatedCapturedFunction::RunWithBorrowedArgs#id=",
            f_opts.step_id, "#");
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunInstantiated(
    const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = &captured_runner_;
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager;
  f_opts.cancellation_manager = &cancellation_manager;

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return absl::StrCat(
            "InstantiatedCapturedFunction::RunInstantiated#id=",
            f_opts.step_id, "#");
      },
      profiler::TraceMeLevel::kInfo);
  TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  return frame.ConsumeRetvals(rets);
}

void InstantiatedCapturedFunction::RunAsync(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    // Run the `done` callback on a threadpool thread, because it will
    // potentially do a non-trivial amount of (e.g. copying) work, and we may
    // want to run that concurrently with the next invocation.
    Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
    (*ctx->runner())(
        std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
                  std::move(done)));
    return;
  }

  // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
  // be deleted before `done` is called. Take care not to capture `ctx` in any
  // code that may execute asynchronously in this function.
  OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
      std::move(args), &captured_func_->captured_inputs(), ret_types_);

  FunctionLibraryRuntime::Options f_opts;
  ResourceMgr* resource_mgr = lib_->device()->resource_manager();
  ScopedStepContainer* step_container = new ScopedStepContainer(
      f_opts.step_id, [resource_mgr](const string& name) {
        resource_mgr->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  auto cancellation_manager =
      absl::make_unique<CancellationManager>(ctx->cancellation_manager());
  f_opts.cancellation_manager = cancellation_manager.get();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  // Transfer ownership of the cancellation manager to `callback`.
  CancellationManager* raw_cancellation_manager =
      cancellation_manager.release();
  auto callback = std::bind(
      [this, rets, step_container, raw_cancellation_manager, frame, node,
       collect_usage](
          const FunctionLibraryRuntime::DoneCallback& done,
          IteratorContext* ctx,
          const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
          // Begin unbound arguments.
          Status s) {
        delete step_container;
        delete raw_cancellation_manager;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets);
        }
        delete frame;
        if (node) {
          // TODO(b/129085499) Utilize the `node_name` which would be unique
          // than the prefix for the function execution time statistics.
          // prefix_with_func_name would then be node_name + func_name.
          if (ctx->stats_aggregator()) {
            string prefix_with_func_name =
                strings::StrCat(node->name(), stats_utils::kDelimiter,
                                captured_func_->func().name());
            ctx->stats_aggregator()->AddToHistogram(
                stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
                {static_cast<float>(stats_collector->processing_time())},
                node->num_elements());
          }
          node->add_processing_time(stats_collector->processing_time());
        }
        if (collect_usage) {
          node->record_start(EnvTime::NowNanos());
        }
        done(s);
        if (collect_usage) {
          node->record_stop(EnvTime::NowNanos());
        }
      },
      std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);

  profiler::TraceMe activity(
      [&] {
        return absl::StrCat(
            "InstantiatedCapturedFunction::RunAsync#id=", f_opts.step_id, "#");
      },
      profiler::TraceMeLevel::kInfo);
  // Stop the usage collection before calling `Run()` because `callback` may
  // be executed synchronously, and so the `node->record_start()` call within
  // `callback` would violate nesting.
  if (collect_usage) node->record_stop(EnvTime::NowNanos());
  lib_->Run(f_opts, f_handle_, frame, std::move(callback));
  if (collect_usage) node->record_start(EnvTime::NowNanos());
}

bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
  // Rendezvous should only be created by the FLR for non-CPU single-device
  // functions. For multi-device functions the appropriate rendezvous will be
  // created by the process FLR.
  return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;
}

}  // namespace data
}  // namespace tensorflow