/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/captured_function.h"

#include <utility>

#include "absl/time/clock.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {
namespace data {
namespace {

// Simplistic implementation of the `StepStatsCollectorInterface` that only
// cares about collecting the CPU time needed to execute a captured function.
class SimpleStepStatsCollector : public StepStatsCollectorInterface {
 public:
  void IncrementProcessingTime(int64_t delta) {
    mutex_lock l(mu_);
    processing_time_ += delta;
  }

  NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override {
    return new SimpleNodeExecStats(this);
  }

  string ReportAllocsOnResourceExhausted(const string& err) override {
    return "";
  }

  int64 processing_time() {
    tf_shared_lock l(mu_);
    return processing_time_;
  }

 private:
  class SimpleNodeExecStats : public NodeExecStatsInterface {
   public:
    explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
        : step_stats_collector_(step_stats_collector) {}

    void Done(const string& device) override {
      step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
                                                     start_time_ns_);
      delete this;
    }

    void RecordExecutorStarted() override {
      start_time_ns_ = absl::GetCurrentTimeNanos();
    }

    void RecordComputeStarted() override {}

    void RecordComputeEnded() override {}

    void RecordExecutorEnded() override {
      end_time_ns_ = absl::GetCurrentTimeNanos();
    }

    bool TrackAllocations() const override { return false; }

    void SetMemory(OpKernelContext* ctx) override {}

    void SetOutput(int slot, const Tensor* tensor) override {}

    void SetScheduled(int64_t nanos) override {}

   private:
    int64 start_time_ns_ = 0;
    int64 end_time_ns_ = 0;
    SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
  };

  mutex mu_;
  int64 processing_time_ TF_GUARDED_BY(mu_) = 0;
};
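
// A minimal usage sketch (illustrative, mirroring how `Run()` below wires the
// collector into the function runtime): the collector is handed to the
// runtime via `FunctionLibraryRuntime::Options::stats_collector`, and the
// accumulated CPU time is read back once the call completes.
//
//   SimpleStepStatsCollector stats_collector;
//   FunctionLibraryRuntime::Options f_opts;
//   f_opts.stats_collector = &stats_collector;
//   TF_RETURN_IF_ERROR(lib->RunSync(std::move(f_opts), f_handle, &frame));
//   int64 cpu_time_ns = stats_collector.processing_time();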

Status GetCapturedInput(const CapturedFunction* const func, int index,
                        const Tensor** out) {
  if (TF_PREDICT_FALSE(index >= func->captured_inputs().size())) {
    return errors::OutOfRange(
        "Out of range access to captured inputs for function ",
        func->func().name(), ". Index: ", index,
        ". Num captured inputs: ", func->captured_inputs().size());
  }
  *out = &func->captured_inputs()[index];
  return Status::OK();
}
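
// Short-circuit example (illustrative): for a function `f(x, y) -> (y, x)`
// whose outputs directly forward its inputs, `info.indices` is {1, 0}, and
// the overloads below return the corresponding argument or captured tensors
// directly (copying, or moving when `info.can_move` permits) without
// invoking the function runtime at all.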

Status RunShortCircuit(const ShortCircuitInfo& info,
                       const std::vector<Tensor>& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      rets->push_back(args[info.indices[i]]);
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return Status::OK();
}

Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      if (info.can_move[i]) {
        rets->push_back(std::move(args[info.indices[i]]));
      } else {
        rets->push_back(args[info.indices[i]]);
      }
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return Status::OK();
}

Status CreateShortCircuitInfo(OpKernelConstruction* ctx,
                              const NameAttrList& func,
                              ShortCircuitInfo* info) {
  auto& indices = info->indices;

  FunctionLibraryRuntime::Handle fn_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), &fn_handle));
  auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
    Status s = ctx->function_library()->ReleaseHandle(fn_handle);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to release handle: " << s.error_message();
    }
  });

  // If the function contains any stateful operations, we conservatively
  // execute the entire function.
  if (ctx->function_library()->IsStateful(func.name())) {
    return Status::OK();
  }

  const FunctionBody* fn_body =
      ctx->function_library()->GetFunctionBody(fn_handle);
  indices.resize(fn_body->ret_nodes.size());

  for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
    Node* ret_node = fn_body->ret_nodes[i];
    Node* ret_input_node;
    TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));

    while (ret_input_node->def().op() == "Identity") {
      TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
    }

    if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(ret_input_node->def(), "index", &(indices[i])));
    } else {
      indices.clear();
      break;
    }
  }

  // Compute the `can_move` vector.
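  // Worked example (illustrative): if `indices` is {0, 0, 1}, argument 0
  // feeds outputs 0 and 1 and argument 1 feeds output 2, so `can_move`
  // becomes {false, true, true}: only the last use of each argument is safe
  // to move.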
  if (!indices.empty()) {
    auto& can_move = info->can_move;
    std::map<int, int> last_use;
    for (size_t i = 0; i < indices.size(); ++i) {
      last_use[indices[i]] = i;
    }
    can_move.resize(indices.size());
    for (int i = 0, end = indices.size(); i < end; ++i) {
      can_move[i] = last_use[indices[i]] == i;
    }
  }

  return Status::OK();
}

Status CreateFunctionLibraryDefinition(
    const FunctionLibraryDefinition* lib_def, const string& func_name,
    std::unique_ptr<FunctionLibraryDefinition>* result) {
  DCHECK(lib_def != nullptr);
  const FunctionDef* fdef = lib_def->Find(func_name);
  if (TF_PREDICT_FALSE(fdef == nullptr)) {
    return errors::FailedPrecondition(strings::StrCat(
        "Could not find required function definition ", func_name));
  }
  *result = absl::make_unique<FunctionLibraryDefinition>(
      lib_def->ReachableDefinitions(*fdef));
  return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
}

Status LookupFunction(const FunctionLibraryDefinition& lib_def,
                      const string& name, const FunctionDef** fdef) {
  *fdef = lib_def.Find(name);
  if (*fdef == nullptr) {
    return errors::InvalidArgument(
        "Failed to find function ", name,
        " in function library: ", lib_def.ToProto().DebugString());
  }
  return Status::OK();
}

class CallFrameBase : public CallFrameInterface {
 public:
  explicit CallFrameBase(DataTypeSlice ret_types)
      : ret_types_(ret_types), retvals_(ret_types.size()) {}

  // Caller methods.
  Status ConsumeRetvals(std::vector<Tensor>* retvals) {
    retvals->reserve(retvals_.size());
    int i = 0;
    for (auto&& val : retvals_) {
      if (!val) {
        return errors::Internal("No return value for index ", i, ".");
      }
      retvals->emplace_back(std::move(val.value()));
      ++i;
    }
    return Status::OK();
  }

  size_t num_retvals() const override { return retvals_.size(); }

  // Callee methods.
  Status SetRetval(int index, const Tensor& val) override {
    const int retvals_size = retvals_.size();
    if (index < retvals_size && val.dtype() == ret_types_[index] &&
        !retvals_[index]) {
      retvals_[index] = val;
      return Status::OK();
    } else if (index >= retvals_size) {
      return errors::InvalidArgument("Return value ", index,
                                     " is out of range.");
    } else if (val.dtype() != ret_types_[index]) {
      return errors::InvalidArgument("Expected type ",
                                     DataTypeString(ret_types_[index]),
                                     " for return value ", index, " but got ",
                                     DataTypeString(val.dtype()), ".");
    } else {
      return errors::Internal("Attempted to set return value ", index,
                              " more than once.");
    }
  }

 private:
  DataTypeSlice ret_types_;
  std::vector<gtl::optional<Tensor>> retvals_;
  TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
};

class OwnedArgsCallFrame : public CallFrameBase {
 public:
  OwnedArgsCallFrame(std::vector<Tensor>&& args,
                     const std::vector<Tensor>* captured_inputs,
                     DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(std::move(args)),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return Status::OK();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_size];
      return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

  // Since we own the argument tensors in `args_`, we can implement
  // `ConsumeArg()` for those arguments.
  void ConsumeArg(int index, Tensor* val) override {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, args_.size());
    *val = std::move(args_[index]);
  }
  bool CanConsumeArg(int index) const override {
    return index >= 0 && index < static_cast<int>(args_.size());
  }

 private:
  std::vector<Tensor> args_;
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

class BorrowedArgsCallFrame : public CallFrameBase {
 public:
  BorrowedArgsCallFrame(const std::vector<Tensor>& args,
                        const std::vector<Tensor>* captured_inputs,
                        DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(args),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return Status::OK();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_size];
      return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

 private:
  const std::vector<Tensor>& args_;                   // Not owned.
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};
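
// The two call frames above differ only in argument ownership:
// `OwnedArgsCallFrame` holds its argument tensors and can therefore hand them
// to the callee via `ConsumeArg()`, whereas `BorrowedArgsCallFrame` aliases
// caller-owned tensors and supports only `GetArg()`. In both cases, captured
// inputs are addressed after the regular arguments.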

}  // namespace

Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator) {
  return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index,
                                      inst_captured_func, prefix, out_iterator,
                                      /*node=*/nullptr);
}

Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node) {
  std::vector<Tensor> return_values;

  TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
      ctx, input_element, &return_values, node));

  if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    return errors::InvalidArgument(
        "Function must return a single scalar of dtype DT_VARIANT.");
  }

  // Retrieve the dataset that was created in `f`.
  DatasetBase* returned_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(return_values[0], &returned_dataset));

  // Create an iterator for the dataset that was returned by `f`.
  std::string iterator_prefix = strings::StrCat(prefix, "[", thread_index, "]");
  if (ctx->split_providers().empty()) {
    return returned_dataset->MakeIterator(ctx, parent, iterator_prefix,
                                          out_iterator);
  }
  // Strip out the split providers so that they don't apply to sub-iterators.
  IteratorContext::Params params(ctx);
  params.split_providers.clear();
  return returned_dataset->MakeIterator(IteratorContext(std::move(params)),
                                        parent, iterator_prefix, out_iterator);
}
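
// Illustrative call site (assumed, not part of this file; the member and
// variable names are hypothetical): dataset iterators such as those of
// `flat_map` and `interleave` use this helper to build one sub-iterator per
// input element.
//
//   std::unique_ptr<IteratorBase> sub_iterator;
//   TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
//       ctx, this, input_element, element_index, *instantiated_captured_func_,
//       prefix(), &sub_iterator));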

/* static */
Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, const string& func_name, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  NameAttrList func;
  TF_RETURN_IF_ERROR(ctx->GetAttr(func_name, &func));
  return Create(ctx, std::move(func), params, out_metadata);
}

Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, NameAttrList&& func, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  out_metadata->reset(new FunctionMetadata(std::move(func), params));
  TF_RETURN_IF_ERROR(CreateFunctionLibraryDefinition(
      ctx->function_library()->GetFunctionLibraryDefinition(),
      (*out_metadata)->func_.name(), &(*out_metadata)->lib_def_));
  TF_RETURN_IF_ERROR(CreateShortCircuitInfo(
      ctx, (*out_metadata)->func_, &(*out_metadata)->short_circuit_info_));
  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(LookupFunction(*(*out_metadata)->lib_def(),
                                    (*out_metadata)->func().name(), &fdef));

  auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
  if (attr != fdef->attr().end() && attr->second.b()) {
    VLOG(1) << "Disabling multi-device execution for a function that uses the "
            << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute.";
    (*out_metadata)->use_multi_device_function_ = false;
    return Status::OK();
  }
  auto validate_arg = [](const OpDef::ArgDef& arg) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      VLOG(1) << "Disabling multi-device execution for a function with "
              << "a vector argument " << arg.name() << ".";
      return false;
    }
    return true;
  };
  for (const auto& arg : fdef->signature().input_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return Status::OK();
    }
  }
  for (const auto& arg : fdef->signature().output_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return Status::OK();
    }
  }
  return Status::OK();
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    const string& argument_name,
    std::unique_ptr<CapturedFunction>* out_function) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
  std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end());
  return Create(ctx, std::move(metadata), std::move(captured_inputs),
                out_function);
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor>&& captured_inputs,
    std::unique_ptr<CapturedFunction>* out_function) {
  *out_function = absl::WrapUnique(
      new CapturedFunction(std::move(metadata), std::move(captured_inputs)));
  return Status::OK();
}
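
// Illustrative kernel wiring (the attr names "f" and "other_arguments" are
// assumptions borrowed from typical tf.data kernels, not requirements of
// this API): the metadata is created once at kernel construction time and a
// CapturedFunction per `Compute()` invocation.
//
//   // In the kernel constructor:
//   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "f", /*params=*/{},
//                                                &func_metadata_));
//   // In Compute():
//   std::unique_ptr<CapturedFunction> captured_func;
//   OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_metadata_,
//                                                "other_arguments",
//                                                &captured_func));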

Status CapturedFunction::AddToGraph(
    SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b,
    std::vector<Node*>* other_arguments,
    DataTypeVector* other_arguments_types) const {
  other_arguments->reserve(captured_inputs_.size());
  other_arguments_types->reserve(captured_inputs_.size());
  for (const Tensor& t : captured_inputs_) {
    Node* node;
    if (ctx->serialize_data_tensors()) {
      TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
    } else {
      TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
      DCHECK_NE(ctx->input_list(), nullptr);
      ctx->input_list()->emplace_back(node->name(), t);
    }
    other_arguments->emplace_back(node);
    other_arguments_types->emplace_back(t.dtype());
  }
  TF_RETURN_IF_ERROR(
      b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def()));
  return Status::OK();
}

// TODO(b/190831948): Check whether the function creates a resource and if so,
// produce a warning.
Status CapturedFunction::Instantiate(
    IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
                              instantiated_captured_function) {
  // The context's runtime will be used for all subsequent calls.
  FunctionLibraryRuntime* lib = ctx->flr();
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.lib_def = metadata_->lib_def();
  inst_opts.create_kernels_eagerly = true;
  inst_opts.default_device_to_target = metadata_->use_default_device();
  inst_opts.config_proto =
      lib->config_proto() ? *lib->config_proto() : ConfigProto();
  if (!metadata_->use_inter_op_parallelism()) {
    inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
  }
  inst_opts.is_multi_device_function = metadata_->use_multi_device_function();
  if (!ctx->function_handle_cache()) {
    // If the caller does not provide a cache, we use the FLR cache.
    inst_opts.use_function_cache = true;
  }

  // We infer the target device from the function library runtime.
  DCHECK(lib->device() != nullptr);
  inst_opts.target = lib->device()->name();

  // Maps from a CompositeDevice name to underlying physical device names.
  absl::flat_hash_map<string, std::vector<string>> composite_devices;

  if (inst_opts.is_multi_device_function) {
    // Compute devices of non-captured inputs.
    //
    // We infer the number of non-captured inputs by subtracting the number
    // of captured inputs from the number of input arguments and we infer the
    // input devices from the function library runtime.
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(
        LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
    size_t num_non_captured_inputs =
        fdef->signature().input_arg_size() - captured_inputs_.size();
    for (size_t i = 0; i < num_non_captured_inputs; ++i) {
      inst_opts.input_devices.push_back(inst_opts.target);
    }
    // Compute devices of captured inputs.
    // TODO(jsimsa): Correctly handle tensors on devices other than CPU:0.
    Device* cpu_device;
    TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
    std::unordered_map<int, DtypeAndPartialTensorShape>&
        input_resource_variable_dtypes_and_shapes =
            inst_opts.input_resource_dtypes_and_shapes;
    for (size_t i = 0; i < captured_inputs_.size(); ++i) {
      const auto& input = captured_inputs_[i];
      DataType dtype = input.dtype();
      if (dtype == DT_RESOURCE) {
        const auto& handles = input.flat<ResourceHandle>();
        const ResourceHandle& handle0 = handles(0);
        string composite_device;
        auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
        if (iter != fdef->arg_attr().end()) {
          auto arg_attr = iter->second.attr().find("_composite_device");
          if (arg_attr != iter->second.attr().end()) {
            composite_device = arg_attr->second.s();
          }
        }
        if (!composite_device.empty()) {
          if (composite_devices.find(composite_device) ==
              composite_devices.end()) {
            for (int i = 0; i < handles.size(); ++i) {
              composite_devices[composite_device].push_back(
                  handles(i).device());
            }
          }
          inst_opts.input_devices.push_back(composite_device);
        } else {
          inst_opts.input_devices.push_back(handle0.device());
        }
        const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
        // Set dtypes and shapes for resource variable inputs.
        if (!dtypes_and_shapes.empty()) {
          input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
                                                    i] =
              dtypes_and_shapes.at(0);
        }
      } else if (MTypeFromDType(dtype) == HOST_MEMORY) {
        inst_opts.input_devices.push_back(cpu_device->name());
      } else {
        // Fall back to using the function library runtime device.
        inst_opts.input_devices.push_back(inst_opts.target);
      }
    }

    for (const auto& it : composite_devices) {
      inst_opts.composite_devices[it.first] = &it.second;
    }

    for (int i = 0, end = fdef->signature().output_arg_size(); i < end; ++i) {
      inst_opts.output_devices.push_back(inst_opts.target);
    }

#if !defined(IS_MOBILE_PLATFORM)
    grappler::GrapplerItem::OptimizationOptions optimization_options;
    optimization_options.allow_pruning_stateful_and_dataset_ops = false;
    ConfigProto config_proto = inst_opts.config_proto;
    // Layout optimizations are excluded because they assume that ops without
    // explicit device assignment will be placed on GPU (if available) but
    // that's not the case for operations within tf.data functions.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_layout_optimizer(RewriterConfig::OFF);
    // TODO(b/120437209): Re-enable constant folding.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_constant_folding(RewriterConfig::OFF);
    inst_opts.optimize_graph_fn =
        std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3,
                  std::placeholders::_4, std::placeholders::_5,
                  std::move(config_proto), fdef->signature().name(),
                  std::move(optimization_options), std::placeholders::_6);
#endif  // !IS_MOBILE_PLATFORM
  }

  FunctionLibraryRuntime::Handle f_handle;
  if (ctx->function_handle_cache()) {
    TF_RETURN_IF_ERROR(ctx->function_handle_cache()->Instantiate(
        metadata_->func().name(), AttrSlice(&metadata_->func().attr()),
        inst_opts, &f_handle));
  } else {
    TF_RETURN_IF_ERROR(lib->Instantiate(metadata_->func().name(),
                                        AttrSlice(&metadata_->func().attr()),
                                        inst_opts, &f_handle));
  }

  DataTypeVector ret_types;
  TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));

  bool is_multi_device;
  TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
  return InstantiatedCapturedFunction::Create(
      lib, f_handle, std::move(ret_types), *ctx->runner(), this,
      is_multi_device, instantiated_captured_function);
}
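
// Typical call sequence (an illustrative sketch, not a prescribed usage; the
// variable names are hypothetical): an iterator instantiates the captured
// function once, then invokes it per element.
//
//   std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
//   TF_RETURN_IF_ERROR(
//       captured_func_->Instantiate(ctx, &instantiated_captured_func));
//   std::vector<Tensor> rets;
//   TF_RETURN_IF_ERROR(
//       instantiated_captured_func->Run(ctx, std::move(args), &rets));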

Status CapturedFunction::CheckExternalState() const {
  for (const auto& name : lib_def()->ListFunctionNames()) {
    TF_RETURN_IF_ERROR(
        IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
  }
  return Status::OK();
}

CapturedFunction::CapturedFunction(
    std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor> captured_inputs)
    : metadata_(std::move(metadata)),
      captured_inputs_(std::move(captured_inputs)) {}

Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
                                       bool* is_multi_device) const {
  if (!metadata_->use_multi_device_function()) {
    *is_multi_device = false;
    return Status::OK();
  }

  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(
      LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));

  Device* current_device = ctx->flr()->device();
  DeviceType current_device_type(current_device->device_type());
  DeviceNameUtils::ParsedName current_device_name;
  if (!DeviceNameUtils::ParseFullName(current_device->name(),
                                      &current_device_name)) {
    return errors::InvalidArgument("Failed to parse device name: ",
                                   current_device->name());
  }

  // Check if any of the captured inputs are placed on a device not compatible
  // with the current device. For non-captured inputs, we assume they are
  // placed on the current device.
  for (const auto& input : captured_inputs_) {
    DataType dtype = input.dtype();
    if (dtype == DT_RESOURCE) {
      const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
      DeviceNameUtils::ParsedName resource_device_name;
      if (!DeviceNameUtils::ParseFullName(handle.device(),
                                          &resource_device_name)) {
        return errors::InvalidArgument("Failed to parse device name: ",
                                       handle.device());
      }
      if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                  resource_device_name)) {
        *is_multi_device = true;
        return Status::OK();
      }
    }
  }

  // Check if all ops could be placed on the current device.
  for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
    for (const auto& node : fdef->node_def()) {
      // Check if the op has a kernel available for the current device.
      if (!KernelDefAvailable(current_device_type, node)) {
        *is_multi_device = true;
        return Status::OK();
      }
      // If the op has a requested device, check if the requested device is
      // compatible with the current device.
      if (!node.device().empty()) {
        DeviceNameUtils::ParsedName node_device_name;
        if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
          return errors::InvalidArgument("Failed to parse device name: ",
                                         node.device());
        }
        if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                    node_device_name)) {
          *is_multi_device = true;
          return Status::OK();
        }
      }
    }
  }

  *is_multi_device = false;
  return Status::OK();
}

/* static */
Status InstantiatedCapturedFunction::Create(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device,
    std::unique_ptr<InstantiatedCapturedFunction>* out_function) {
  out_function->reset(new InstantiatedCapturedFunction(
      lib, f_handle, ret_types, runner, captured_func, is_multi_device));
  return Status::OK();
}

InstantiatedCapturedFunction::InstantiatedCapturedFunction(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device)
    : lib_(lib),
      f_handle_(f_handle),
      ret_types_(std::move(ret_types)),
      captured_runner_(std::move(runner)),
      captured_func_(captured_func),
      is_multi_device_(is_multi_device) {}

Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
                                         std::vector<Tensor>&& args,
                                         std::vector<Tensor>* rets) const {
  return Run(ctx, std::move(args), rets, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::Run(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, std::move(args), captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
                           ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode("InstantiatedCapturedFunction::Run",
                                       {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    // TODO(jsimsa): Factor out common code for Run, RunAsync, and
    // RunWithBorrowedArgs
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* ret) const {
  return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode(
            "InstantiatedCapturedFunction::RunWithBorrowedArgs",
            {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunInstantiated(
    const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = &captured_runner_;
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager;
  f_opts.cancellation_manager = &cancellation_manager;

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode(
            "InstantiatedCapturedFunction::RunInstantiated",
            {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  return frame.ConsumeRetvals(rets);
}

void InstantiatedCapturedFunction::RunAsync(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    // Run the `done` callback on a threadpool thread, because it will
    // potentially do a non-trivial amount of (e.g. copying) work, and we may
    // want to run that concurrently with the next invocation.
    Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
    (*ctx->runner())(
        std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
                  std::move(done)));
    return;
  }

  // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
  // be deleted before `done` is called. Take care not to capture `ctx` in any
  // code that may execute asynchronously in this function.
  OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
      std::move(args), &captured_func_->captured_inputs(), ret_types_);

  FunctionLibraryRuntime::Options f_opts;
  ResourceMgr* resource_mgr = lib_->device()->resource_manager();
  ScopedStepContainer* step_container = new ScopedStepContainer(
      f_opts.step_id, [resource_mgr](const string& name) {
        resource_mgr->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  auto cancellation_manager =
      absl::make_unique<CancellationManager>(ctx->cancellation_manager());
  f_opts.cancellation_manager = cancellation_manager.get();
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  // Transfer ownership of the cancellation manager to `callback`.
  CancellationManager* raw_cancellation_manager =
      cancellation_manager.release();
  auto callback = std::bind(
      [this, rets, step_container, raw_cancellation_manager, frame, node,
       collect_usage](
          const FunctionLibraryRuntime::DoneCallback& done,
          IteratorContext* ctx,
          const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
          // Begin unbound arguments.
          Status s) {
        delete step_container;
        delete raw_cancellation_manager;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets);
        }
        delete frame;
        if (node) {
          // TODO(b/129085499): Utilize the `node_name`, which would be more
          // unique than the prefix, for the function execution time
          // statistics; `prefix_with_func_name` would then be `node_name` +
          // `func_name`.
          if (ctx->stats_aggregator()) {
            string prefix_with_func_name =
                strings::StrCat(node->name(), stats_utils::kDelimiter,
                                captured_func_->func().name());
            ctx->stats_aggregator()->AddToHistogram(
                stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
                {static_cast<float>(stats_collector->processing_time())},
                node->num_elements());
          }
          node->add_processing_time(stats_collector->processing_time());
        }
        if (collect_usage) {
          node->record_start(EnvTime::NowNanos());
        }
        done(s);
        if (collect_usage) {
          node->record_stop(EnvTime::NowNanos());
        }
      },
      std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);

  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode("InstantiatedCapturedFunction::RunAsync",
                                       {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  // Stop the usage collection before calling `Run()` because `callback` may
  // be executed synchronously, and so the `node->record_start()` call within
  // `callback` would violate nesting.
  if (collect_usage) node->record_stop(EnvTime::NowNanos());
  lib_->Run(f_opts, f_handle_, frame, std::move(callback));
  if (collect_usage) node->record_start(EnvTime::NowNanos());
}

bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
  // A rendezvous should only be created by the FLR for non-CPU single-device
  // functions. For multi-device functions, the appropriate rendezvous will be
  // created by the process FLR.
  return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;
}

}  // namespace data
}  // namespace tensorflow