/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/captured_function.h"

#include <utility>

#include "absl/time/clock.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {
namespace data {
namespace {

// Simplistic implementation of the `StepStatsCollectorInterface` that only
// cares about collecting the CPU time needed to execute a captured function.
class SimpleStepStatsCollector : public StepStatsCollectorInterface {
 public:
  void IncrementProcessingTime(int64_t delta) {
    mutex_lock l(mu_);
    processing_time_ += delta;
  }

  NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override {
    return new SimpleNodeExecStats(this);
  }

  string ReportAllocsOnResourceExhausted(const string& err) override {
    return "";
  }

  int64 processing_time() {
    tf_shared_lock l(mu_);
    return processing_time_;
  }

 private:
  class SimpleNodeExecStats : public NodeExecStatsInterface {
   public:
    explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
        : step_stats_collector_(step_stats_collector) {}

    void Done(const string& device) override {
      step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
                                                     start_time_ns_);
      delete this;
    }

    void RecordExecutorStarted() override {
      start_time_ns_ = absl::GetCurrentTimeNanos();
    }

    void RecordComputeStarted() override {}

    void RecordComputeEnded() override {}

    void RecordExecutorEnded() override {
      end_time_ns_ = absl::GetCurrentTimeNanos();
    }

    bool TrackAllocations() const override { return false; }

    void SetMemory(OpKernelContext* ctx) override {}

    void SetOutput(int slot, const Tensor* tensor) override {}

    void SetScheduled(int64_t nanos) override {}

   private:
    int64 start_time_ns_ = 0;
    int64 end_time_ns_ = 0;
    SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
  };

  mutex mu_;
  int64 processing_time_ TF_GUARDED_BY(mu_) = 0;
};

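// A minimal usage sketch (illustrative only; `lib`, `f_handle`, and `frame`
// are placeholders): the collector is handed to the function library runtime
// via `FunctionLibraryRuntime::Options::stats_collector`, and the accumulated
// CPU time is read back once the run completes:
//
//   SimpleStepStatsCollector stats_collector;
//   FunctionLibraryRuntime::Options f_opts;
//   f_opts.stats_collector = &stats_collector;
//   TF_RETURN_IF_ERROR(lib->RunSync(std::move(f_opts), f_handle, &frame));
//   int64 cpu_time_ns = stats_collector.processing_time();
//
// This mirrors how `InstantiatedCapturedFunction::Run` below wires the
// collector into `f_opts` before calling `RunSync`.
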
Status GetCapturedInput(const CapturedFunction* const func, int index,
                        const Tensor** out) {
  if (TF_PREDICT_FALSE(index >= func->captured_inputs().size())) {
    return errors::OutOfRange(
        "Out of range access to captured inputs for function ",
        func->func().name(), ". Index: ", index,
        ". Num captured inputs: ", func->captured_inputs().size());
  }
  *out = &func->captured_inputs()[index];
  return Status::OK();
}

Status RunShortCircuit(const ShortCircuitInfo& info,
                       const std::vector<Tensor>& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      rets->push_back(args[info.indices[i]]);
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return Status::OK();
}

Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      if (info.can_move[i]) {
        rets->push_back(std::move(args[info.indices[i]]));
      } else {
        rets->push_back(args[info.indices[i]]);
      }
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return Status::OK();
}

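// Short-circuiting applies to functions whose outputs merely forward a subset
// of their inputs (explicit arguments or captured inputs). For instance, a
// function computing f(x, y) = (y, x) never needs the executor: its
// `info.indices` would be {1, 0}, and the overloads above copy (or, in the
// rvalue overload, move when `info.can_move` permits) the corresponding
// tensors straight into `rets`.
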
Status CreateShortCircuitInfo(OpKernelConstruction* ctx,
                              const NameAttrList& func,
                              ShortCircuitInfo* info) {
  auto& indices = info->indices;

  FunctionLibraryRuntime::Handle fn_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), &fn_handle));
  auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
    Status s = ctx->function_library()->ReleaseHandle(fn_handle);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to release handle: " << s.error_message();
    }
  });

  // If the function contains any stateful operations, we conservatively
  // execute the entire function.
  if (ctx->function_library()->IsStateful(func.name())) {
    return Status::OK();
  }

  const FunctionBody* fn_body =
      ctx->function_library()->GetFunctionBody(fn_handle);
  indices.resize(fn_body->ret_nodes.size());

  for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
    Node* ret_node = fn_body->ret_nodes[i];
    Node* ret_input_node;
    TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));

    while (ret_input_node->def().op() == "Identity") {
      TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
    }

    if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(ret_input_node->def(), "index", &(indices[i])));
    } else {
      indices.clear();
      break;
    }
  }

  // Compute the `can_move` vector.
  if (!indices.empty()) {
    auto& can_move = info->can_move;
    std::map<int, int> last_use;
    for (size_t i = 0; i < indices.size(); ++i) {
      last_use[indices[i]] = i;
    }
    can_move.resize(indices.size());
    for (int i = 0, end = indices.size(); i < end; ++i) {
      can_move[i] = last_use[indices[i]] == i;
    }
  }

  return Status::OK();
}

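// Worked example of the `can_move` computation above: for a function that
// returns (arg1, arg0, arg1), `indices` is {1, 0, 1}. Argument 1 is used at
// return positions 0 and 2, so `last_use[1] == 2`, and only the final use of
// each argument may steal its tensor: `can_move` is {false, true, true}.
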
Status CreateFunctionLibraryDefinition(
    const FunctionLibraryDefinition* lib_def, const string& func_name,
    std::unique_ptr<FunctionLibraryDefinition>* result) {
  DCHECK(lib_def != nullptr);
  const FunctionDef* fdef = lib_def->Find(func_name);
  if (TF_PREDICT_FALSE(fdef == nullptr)) {
    return errors::FailedPrecondition(strings::StrCat(
        "Could not find required function definition ", func_name));
  }
  *result = absl::make_unique<FunctionLibraryDefinition>(
      lib_def->ReachableDefinitions(*fdef));
  return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
}

Status LookupFunction(const FunctionLibraryDefinition& lib_def,
                      const string& name, const FunctionDef** fdef) {
  *fdef = lib_def.Find(name);
  if (*fdef == nullptr) {
    return errors::InvalidArgument(
        "Failed to find function ", name,
        " in function library: ", lib_def.ToProto().DebugString());
  }
  return Status::OK();
}

class CallFrameBase : public CallFrameInterface {
 public:
  explicit CallFrameBase(DataTypeSlice ret_types)
      : ret_types_(ret_types), retvals_(ret_types.size()) {}

  // Caller methods.
  Status ConsumeRetvals(std::vector<Tensor>* retvals) {
    retvals->reserve(retvals_.size());
    int i = 0;
    for (auto&& val : retvals_) {
      if (!val) {
        return errors::Internal("No return value for index ", i, ".");
      }
      retvals->emplace_back(std::move(val.value()));
      ++i;
    }
    return Status::OK();
  }

  size_t num_retvals() const override { return retvals_.size(); }

  // Callee methods.
  Status SetRetval(int index, const Tensor& val) override {
    const int retvals_size = retvals_.size();
    if (index < retvals_size && val.dtype() == ret_types_[index] &&
        !retvals_[index]) {
      retvals_[index] = val;
      return Status::OK();
    } else if (index >= retvals_size) {
      return errors::InvalidArgument("Return value ", index,
                                     " is out of range.");
    } else if (val.dtype() != ret_types_[index]) {
      return errors::InvalidArgument("Expected type ",
                                     DataTypeString(ret_types_[index]),
                                     " for return value ", index, " but got ",
                                     DataTypeString(val.dtype()), ".");
    } else {
      return errors::Internal("Attempted to set return value ", index,
                              " more than once.");
    }
  }

 private:
  DataTypeSlice ret_types_;
  std::vector<gtl::optional<Tensor>> retvals_;
  TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
};

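// In both subclasses below, the callee-visible argument vector is the
// concatenation of the explicit arguments followed by the captured inputs:
// indices [0, args.size()) resolve to `args_`, and indices
// [args.size(), num_args()) resolve to `*captured_inputs_`. For example, a
// call with two explicit arguments and one captured input exposes the
// captured tensor at index 2.
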
class OwnedArgsCallFrame : public CallFrameBase {
 public:
  OwnedArgsCallFrame(std::vector<Tensor>&& args,
                     const std::vector<Tensor>* captured_inputs,
                     DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(std::move(args)),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return Status::OK();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_.size()];
      return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

  // Since we own the argument tensors in `args_`, we can implement
  // `ConsumeArg()` for those arguments.
  void ConsumeArg(int index, Tensor* val) override {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, args_.size());
    *val = std::move(args_[index]);
  }
  bool CanConsumeArg(int index) const override {
    return index >= 0 && index < static_cast<int>(args_.size());
  }

 private:
  std::vector<Tensor> args_;
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

class BorrowedArgsCallFrame : public CallFrameBase {
 public:
  BorrowedArgsCallFrame(const std::vector<Tensor>& args,
                        const std::vector<Tensor>* captured_inputs,
                        DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(args),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return Status::OK();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_size];
      return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

 private:
  const std::vector<Tensor>& args_;                   // Not owned.
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

}  // namespace

Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator) {
  return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index,
                                      inst_captured_func, prefix, out_iterator,
                                      /*node=*/nullptr);
}

Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node) {
  std::vector<Tensor> return_values;

  TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
      ctx, input_element, &return_values, node));

  if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    return errors::InvalidArgument(
        "Function must return a single scalar of dtype DT_VARIANT.");
  }

  // Retrieve the dataset that was created in `f`.
  DatasetBase* returned_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(return_values[0], &returned_dataset));

  // Create an iterator for the dataset that was returned by `f`.
  std::string iterator_prefix = strings::StrCat(prefix, "[", thread_index, "]");
  if (ctx->split_providers().empty()) {
    return returned_dataset->MakeIterator(ctx, parent, iterator_prefix,
                                          out_iterator);
  }
  // Strip out the split providers so that they don't apply to sub-iterators.
  IteratorContext::Params params(ctx);
  params.split_providers.clear();
  return returned_dataset->MakeIterator(IteratorContext(std::move(params)),
                                        parent, iterator_prefix, out_iterator);
}

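// This helper backs dataset transformations whose user-defined function
// itself returns a dataset (e.g. `flat_map`- or `interleave`-style ops): the
// function is run over `input_element`, the resulting DT_VARIANT scalar is
// unwrapped into a `DatasetBase`, and an iterator over it is created under a
// per-element prefix, e.g. "prefix[3]" for `thread_index` 3.
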
/* static */
Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, const string& func_name, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  NameAttrList func;
  TF_RETURN_IF_ERROR(ctx->GetAttr(func_name, &func));
  return Create(ctx, std::move(func), params, out_metadata);
}

Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, NameAttrList&& func, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  out_metadata->reset(new FunctionMetadata(std::move(func), params));
  TF_RETURN_IF_ERROR(CreateFunctionLibraryDefinition(
      ctx->function_library()->GetFunctionLibraryDefinition(),
      (*out_metadata)->func_.name(), &(*out_metadata)->lib_def_));
  TF_RETURN_IF_ERROR(CreateShortCircuitInfo(
      ctx, (*out_metadata)->func_, &(*out_metadata)->short_circuit_info_));
  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(LookupFunction(*(*out_metadata)->lib_def(),
                                    (*out_metadata)->func().name(), &fdef));

  auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
  if (attr != fdef->attr().end() && attr->second.b()) {
    VLOG(1) << "Disabling multi-device execution for a function that uses the "
            << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute.";
    (*out_metadata)->use_multi_device_function_ = false;
    return Status::OK();
  }
  auto validate_arg = [](const OpDef::ArgDef& arg) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      VLOG(1) << "Disabling multi-device execution for a function with "
              << "a vector argument " << arg.name() << ".";
      return false;
    }
    return true;
  };
  for (const auto& arg : fdef->signature().input_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return Status::OK();
    }
  }
  for (const auto& arg : fdef->signature().output_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return Status::OK();
    }
  }
  return Status::OK();
}

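// Illustrative example of the `validate_arg` check above: an argument
// declared with a `number_attr` (e.g. "x: N * float" in op-registration
// notation) or a `type_list_attr` expands to a variable-length list of
// tensors, so any function whose signature contains such an argument falls
// back to single-device execution.
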
/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    const string& argument_name,
    std::unique_ptr<CapturedFunction>* out_function) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
  std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end());
  return Create(ctx, std::move(metadata), std::move(captured_inputs),
                out_function);
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor>&& captured_inputs,
    std::unique_ptr<CapturedFunction>* out_function) {
  *out_function = absl::WrapUnique(
      new CapturedFunction(std::move(metadata), std::move(captured_inputs)));
  return Status::OK();
}

Status CapturedFunction::AddToGraph(
    SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b,
    std::vector<Node*>* other_arguments,
    DataTypeVector* other_arguments_types) const {
  other_arguments->reserve(captured_inputs_.size());
  other_arguments_types->reserve(captured_inputs_.size());
  for (const Tensor& t : captured_inputs_) {
    Node* node;
    if (ctx->serialize_data_tensors()) {
      TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
    } else {
      TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
      DCHECK_NE(ctx->input_list(), nullptr);
      ctx->input_list()->emplace_back(node->name(), t);
    }
    other_arguments->emplace_back(node);
    other_arguments_types->emplace_back(t.dtype());
  }
  TF_RETURN_IF_ERROR(
      b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def()));
  return Status::OK();
}

// TODO(b/190831948): Check whether the function creates a resource and if so,
// produce a warning.
Status CapturedFunction::Instantiate(
    IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
                              instantiated_captured_function) {
  // The context's runtime will be used for all subsequent calls.
  FunctionLibraryRuntime* lib = ctx->flr();
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.lib_def = metadata_->lib_def();
  inst_opts.create_kernels_eagerly = true;
  inst_opts.default_device_to_target = metadata_->use_default_device();
  inst_opts.config_proto =
      lib->config_proto() ? *lib->config_proto() : ConfigProto();
  if (!metadata_->use_inter_op_parallelism()) {
    inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
  }
  inst_opts.is_multi_device_function = metadata_->use_multi_device_function();
  if (!ctx->function_handle_cache()) {
    // If the caller does not provide a cache, we use the FLR cache.
    inst_opts.use_function_cache = true;
  }

  // We infer the target device from the function library runtime.
  DCHECK(lib->device() != nullptr);
  inst_opts.target = lib->device()->name();

  // Maps from a CompositeDevice name to underlying physical device names.
  absl::flat_hash_map<string, std::vector<string>> composite_devices;

  if (inst_opts.is_multi_device_function) {
    // Compute devices of non-captured inputs.
    //
    // We infer the number of non-captured inputs by subtracting the number
    // of captured inputs from the number of input arguments and we infer the
    // input devices from the function library runtime.
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(
        LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
    size_t num_non_captured_inputs =
        fdef->signature().input_arg_size() - captured_inputs_.size();
    for (size_t i = 0; i < num_non_captured_inputs; ++i) {
      inst_opts.input_devices.push_back(inst_opts.target);
    }
    // Compute devices of captured inputs.
    // TODO(jsimsa): Correctly handle tensors on devices other than CPU:0.
    Device* cpu_device;
    TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
    std::unordered_map<int, DtypeAndPartialTensorShape>&
        input_resource_variable_dtypes_and_shapes =
            inst_opts.input_resource_dtypes_and_shapes;
    for (size_t i = 0; i < captured_inputs_.size(); ++i) {
      const auto& input = captured_inputs_[i];
      DataType dtype = input.dtype();
      if (dtype == DT_RESOURCE) {
        const auto& handles = input.flat<ResourceHandle>();
        const ResourceHandle& handle0 = handles(0);
        string composite_device;
        auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
        if (iter != fdef->arg_attr().end()) {
          auto arg_attr = iter->second.attr().find("_composite_device");
          if (arg_attr != iter->second.attr().end()) {
            composite_device = arg_attr->second.s();
          }
        }
        if (!composite_device.empty()) {
          if (composite_devices.find(composite_device) ==
              composite_devices.end()) {
            for (int i = 0; i < handles.size(); ++i) {
              composite_devices[composite_device].push_back(
                  handles(i).device());
            }
          }
          inst_opts.input_devices.push_back(composite_device);
        } else {
          inst_opts.input_devices.push_back(handle0.device());
        }
        const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
        // Set dtypes and shapes for resource variable inputs.
        if (!dtypes_and_shapes.empty()) {
          input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
                                                    i] =
              dtypes_and_shapes.at(0);
        }
      } else if (MTypeFromDType(dtype) == HOST_MEMORY) {
        inst_opts.input_devices.push_back(cpu_device->name());
      } else {
        // Fall back to using the function library runtime device.
        inst_opts.input_devices.push_back(inst_opts.target);
      }
    }

    for (const auto& it : composite_devices) {
      inst_opts.composite_devices[it.first] = &it.second;
    }

    for (int i = 0, end = fdef->signature().output_arg_size(); i < end; ++i) {
      inst_opts.output_devices.push_back(inst_opts.target);
    }

#if !defined(IS_MOBILE_PLATFORM)
    grappler::GrapplerItem::OptimizationOptions optimization_options;
    optimization_options.allow_pruning_stateful_and_dataset_ops = false;
    ConfigProto config_proto = inst_opts.config_proto;
    // Layout optimizations are excluded because they assume that ops without
    // explicit device assignment will be placed on GPU (if available) but
    // that's not the case for operations within tf.data functions.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_layout_optimizer(RewriterConfig::OFF);
    // TODO(b/120437209): Re-enable constant folding.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_constant_folding(RewriterConfig::OFF);
    inst_opts.optimize_graph_fn =
        std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3,
                  std::placeholders::_4, std::placeholders::_5,
                  std::move(config_proto), fdef->signature().name(),
                  std::move(optimization_options), std::placeholders::_6);
#endif  // !IS_MOBILE_PLATFORM
  }

  FunctionLibraryRuntime::Handle f_handle;
  if (ctx->function_handle_cache()) {
    TF_RETURN_IF_ERROR(ctx->function_handle_cache()->Instantiate(
        metadata_->func().name(), AttrSlice(&metadata_->func().attr()),
        inst_opts, &f_handle));
  } else {
    TF_RETURN_IF_ERROR(lib->Instantiate(metadata_->func().name(),
                                        AttrSlice(&metadata_->func().attr()),
                                        inst_opts, &f_handle));
  }

  DataTypeVector ret_types;
  TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));

  bool is_multi_device;
  TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
  return InstantiatedCapturedFunction::Create(
      lib, f_handle, std::move(ret_types), *ctx->runner(), this,
      is_multi_device, instantiated_captured_function);
}

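// A minimal caller-side sketch (hypothetical iterator code, not part of this
// file): a dataset iterator typically instantiates the captured function once
// in its Initialize method and reuses the result for every element, e.g.:
//
//   Status Initialize(IteratorContext* ctx) override {
//     return dataset()->captured_func_->Instantiate(
//         ctx, &instantiated_captured_func_);
//   }
//
// after which each GetNextInternal call can invoke
// `instantiated_captured_func_->Run(...)` or `RunWithBorrowedArgs(...)`.
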
Status CapturedFunction::CheckExternalState() const {
  for (const auto& name : lib_def()->ListFunctionNames()) {
    TF_RETURN_IF_ERROR(
        IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
  }
  return Status::OK();
}

CapturedFunction::CapturedFunction(
    std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor> captured_inputs)
    : metadata_(std::move(metadata)),
      captured_inputs_(std::move(captured_inputs)) {}

Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
                                       bool* is_multi_device) const {
  if (!metadata_->use_multi_device_function()) {
    *is_multi_device = false;
    return Status::OK();
  }

  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(
      LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));

  Device* current_device = ctx->flr()->device();
  DeviceType current_device_type(current_device->device_type());
  DeviceNameUtils::ParsedName current_device_name;
  if (!DeviceNameUtils::ParseFullName(current_device->name(),
                                      &current_device_name)) {
    return errors::InvalidArgument("Failed to parse device name: ",
                                   current_device->name());
  }

  // Check if any of the captured inputs are placed on a device not compatible
  // with the current device. For non-captured inputs, we assume they are
  // placed on the current device.
  for (const auto& input : captured_inputs_) {
    DataType dtype = input.dtype();
    if (dtype == DT_RESOURCE) {
      const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
      DeviceNameUtils::ParsedName resource_device_name;
      if (!DeviceNameUtils::ParseFullName(handle.device(),
                                          &resource_device_name)) {
        return errors::InvalidArgument("Failed to parse device name: ",
                                       handle.device());
      }
      if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                  resource_device_name)) {
        *is_multi_device = true;
        return Status::OK();
      }
    }
  }

  // Check if all ops could be placed on the current device.
  for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
    for (const auto& node : fdef->node_def()) {
      // Check if the op has a kernel available for the current device.
      if (!KernelDefAvailable(current_device_type, node)) {
        *is_multi_device = true;
        return Status::OK();
      }
      // If the op has a requested device, check if the requested device is
      // compatible with the current device.
      if (!node.device().empty()) {
        DeviceNameUtils::ParsedName node_device_name;
        if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
          return errors::InvalidArgument("Failed to parse device name: ",
                                         node.device());
        }
        if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                    node_device_name)) {
          *is_multi_device = true;
          return Status::OK();
        }
      }
    }
  }

  *is_multi_device = false;
  return Status::OK();
}

/* static */
Status InstantiatedCapturedFunction::Create(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device,
    std::unique_ptr<InstantiatedCapturedFunction>* out_function) {
  out_function->reset(new InstantiatedCapturedFunction(
      lib, f_handle, ret_types, runner, captured_func, is_multi_device));
  return Status::OK();
}

InstantiatedCapturedFunction::InstantiatedCapturedFunction(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device)
    : lib_(lib),
      f_handle_(f_handle),
      ret_types_(std::move(ret_types)),
      captured_runner_(std::move(runner)),
      captured_func_(captured_func),
      is_multi_device_(is_multi_device) {}

Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
                                         std::vector<Tensor>&& args,
                                         std::vector<Tensor>* rets) const {
  return Run(ctx, std::move(args), rets, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::Run(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, std::move(args), captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
                           ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode("InstantiatedCapturedFunction::Run",
                                       {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    // TODO(jsimsa): Factor out common code for `Run`, `RunAsync`, and
    // `RunWithBorrowedArgs`.
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* ret) const {
  return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode(
            "InstantiatedCapturedFunction::RunWithBorrowedArgs",
            {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunInstantiated(
    const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = &captured_runner_;
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager;
  f_opts.cancellation_manager = &cancellation_manager;

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode(
            "InstantiatedCapturedFunction::RunInstantiated",
            {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  return frame.ConsumeRetvals(rets);
}

void InstantiatedCapturedFunction::RunAsync(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    // Run the `done` callback on a threadpool thread, because it will
    // potentially do a non-trivial amount of (e.g. copying) work, and we may
    // want to run that concurrently with the next invocation.
    Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
    (*ctx->runner())(
        std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
                  std::move(done)));
    return;
  }

  // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
  // be deleted before `done` is called. Take care not to capture `ctx` in any
  // code that may execute asynchronously in this function.
  OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
      std::move(args), &captured_func_->captured_inputs(), ret_types_);

  FunctionLibraryRuntime::Options f_opts;
  ResourceMgr* resource_mgr = lib_->device()->resource_manager();
  ScopedStepContainer* step_container = new ScopedStepContainer(
      f_opts.step_id, [resource_mgr](const string& name) {
        resource_mgr->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  auto cancellation_manager =
      absl::make_unique<CancellationManager>(ctx->cancellation_manager());
  f_opts.cancellation_manager = cancellation_manager.get();
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage =
      node && ctx->model() && ctx->model()->collect_resource_usage();
  f_opts.stats_collector = stats_collector.get();

  // Transfer ownership of the cancellation manager to `callback`.
  CancellationManager* raw_cancellation_manager =
      cancellation_manager.release();
  auto callback = std::bind(
      [this, rets, step_container, raw_cancellation_manager, frame, node,
       collect_usage](
          const FunctionLibraryRuntime::DoneCallback& done,
          IteratorContext* ctx,
          const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
          // Begin unbound arguments.
          Status s) {
        delete step_container;
        delete raw_cancellation_manager;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets);
        }
        delete frame;
        if (node) {
          // TODO(b/129085499): Utilize the `node_name`, which would be more
          // unique than the prefix, for the function execution time
          // statistics; `prefix_with_func_name` would then be `node_name` +
          // `func_name`.
          if (ctx->stats_aggregator()) {
            string prefix_with_func_name =
                strings::StrCat(node->name(), stats_utils::kDelimiter,
                                captured_func_->func().name());
            ctx->stats_aggregator()->AddToHistogram(
                stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
                {static_cast<float>(stats_collector->processing_time())},
                node->num_elements());
          }
          node->add_processing_time(stats_collector->processing_time());
        }
        if (collect_usage) {
          node->record_start(EnvTime::NowNanos());
        }
        done(s);
        if (collect_usage) {
          node->record_stop(EnvTime::NowNanos());
        }
      },
      std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);

  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode("InstantiatedCapturedFunction::RunAsync",
                                       {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  // Stop the usage collection before calling `Run()` because `callback` may
  // be executed synchronously, and so the `node->record_start()` call within
  // `callback` would violate nesting.
  if (collect_usage) node->record_stop(EnvTime::NowNanos());
  lib_->Run(f_opts, f_handle_, frame, std::move(callback));
  if (collect_usage) node->record_start(EnvTime::NowNanos());
}

bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
  // Rendezvous should only be created by the FLR for non-CPU single-device
  // functions. For multi-device functions the appropriate rendezvous will be
  // created by the process FLR.
  return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;
}

}  // namespace data
}  // namespace tensorflow