• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/captured_function.h"
16 
17 #include <utility>
18 
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/common_runtime/step_stats_collector.h"
21 #include "tensorflow/core/framework/cancellation.h"
22 #include "tensorflow/core/framework/function_handle_cache.h"
23 #include "tensorflow/core/framework/stats_aggregator.h"
24 #include "tensorflow/core/kernels/data/stats_utils.h"
25 #include "tensorflow/core/lib/gtl/optional.h"
26 #include "tensorflow/core/lib/random/random.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/platform/notification.h"
29 
30 namespace tensorflow {
31 namespace data {
32 
33 namespace {
34 
35 // Simplistic implementation of the `StepStatsCollectorInterface` that only
36 // cares about collecting the CPU time needed to execute a captured function.
37 class SimpleStepStatsCollector : public StepStatsCollectorInterface {
38  public:
IncrementProcessingTime(int64 delta)39   void IncrementProcessingTime(int64 delta) {
40     mutex_lock l(mu_);
41     processing_time_ += delta;
42   }
43 
CreateNodeExecStats(const Node * node)44   NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
45     return new SimpleNodeExecStats(this);
46   }
47 
ReportAllocsOnResourceExhausted(const string & err)48   string ReportAllocsOnResourceExhausted(const string& err) override {
49     return "";
50   }
51 
processing_time()52   int64 processing_time() {
53     tf_shared_lock l(mu_);
54     return processing_time_;
55   }
56 
57  private:
58   class SimpleNodeExecStats : public NodeExecStatsInterface {
59    public:
SimpleNodeExecStats(SimpleStepStatsCollector * step_stats_collector)60     explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
61         : step_stats_collector_(step_stats_collector) {}
62 
Done(const string & device)63     void Done(const string& device) override {
64       step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
65                                                      start_time_ns_);
66       delete this;
67     }
68 
RecordExecutorStarted()69     void RecordExecutorStarted() override {
70       start_time_ns_ = Env::Default()->NowNanos();
71     }
72 
RecordComputeStarted()73     void RecordComputeStarted() override {}
74 
RecordComputeEnded()75     void RecordComputeEnded() override {}
76 
RecordExecutorEnded()77     void RecordExecutorEnded() override {
78       end_time_ns_ = Env::Default()->NowNanos();
79     }
80 
TrackAllocations() const81     bool TrackAllocations() const override { return false; }
82 
SetMemory(OpKernelContext * ctx)83     void SetMemory(OpKernelContext* ctx) override {}
84 
SetOutput(int slot,const Tensor * tensor)85     void SetOutput(int slot, const Tensor* tensor) override {}
86 
SetReferencedTensors(const TensorReferenceVector & tensors)87     void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
88 
SetScheduled(int64 nanos)89     void SetScheduled(int64 nanos) override {}
90 
91    private:
92     int64 start_time_ns_ = 0;
93     int64 end_time_ns_ = 0;
94     SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
95   };
96 
97   mutex mu_;
98   int64 processing_time_ GUARDED_BY(mu_) = 0;
99 };
100 
101 }  // namespace
102 
103 /* static */
Create(const NameAttrList & func,OpKernelContext * ctx,const string & argument_name,std::unique_ptr<CapturedFunction> * out_function)104 Status CapturedFunction::Create(
105     const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
106     std::unique_ptr<CapturedFunction>* out_function) {
107   return CapturedFunction::Create(func, ctx, argument_name, true, out_function);
108 }
109 
Create(const NameAttrList & func,OpKernelContext * ctx,const string & argument_name,bool use_inter_op_parallelism,std::unique_ptr<CapturedFunction> * out_function)110 Status CapturedFunction::Create(
111     const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
112     bool use_inter_op_parallelism,
113     std::unique_ptr<CapturedFunction>* out_function) {
114   OpInputList inputs;
115   TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
116   std::vector<Tensor> arguments(inputs.begin(), inputs.end());
117   *out_function = absl::WrapUnique(new CapturedFunction(
118       func, std::move(arguments), use_inter_op_parallelism));
119   return Status::OK();
120 }
121 
Create(const NameAttrList & func,OpKernelContext * ctx,std::vector<Tensor> && captured_inputs,bool use_inter_op_parallelism,std::unique_ptr<CapturedFunction> * out_function)122 Status CapturedFunction::Create(
123     const NameAttrList& func, OpKernelContext* ctx,
124     std::vector<Tensor>&& captured_inputs, bool use_inter_op_parallelism,
125     std::unique_ptr<CapturedFunction>* out_function) {
126   *out_function = absl::WrapUnique(new CapturedFunction(
127       func, std::move(captured_inputs), use_inter_op_parallelism));
128   return Status::OK();
129 }
130 
Instantiate(IteratorContext * ctx,std::unique_ptr<InstantiatedCapturedFunction> * instantiated_captured_function)131 Status CapturedFunction::Instantiate(
132     IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
133                               instantiated_captured_function) {
134   // The context's runtime will be used for all subsequent calls.
135   FunctionLibraryRuntime* lib = ctx->lib();
136   FunctionLibraryRuntime::InstantiateOptions inst_opts;
137   inst_opts.overlay_lib = ctx->function_library().get();
138   inst_opts.create_kernels_eagerly = true;
139   if (!use_inter_op_parallelism_) {
140     inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
141   }
142 
143   FunctionLibraryRuntime::Handle f_handle;
144   TF_RETURN_IF_ERROR(ctx->function_handle_cache()->Instantiate(
145       func_.name(), AttrSlice(&func_.attr()), inst_opts, &f_handle));
146   const FunctionBody* fbody = lib->GetFunctionBody(f_handle);
147   if (fbody == nullptr) {
148     return errors::Internal("Failed to instantiate function body.");
149   }
150 
151   DataTypeVector ret_types;
152   for (const auto& ret_type : fbody->ret_types) {
153     ret_types.push_back(ret_type);
154   }
155 
156   *instantiated_captured_function =
157       absl::WrapUnique<InstantiatedCapturedFunction>(
158           new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types),
159                                            *ctx->runner(), this));
160   return Status::OK();
161 }
162 
163 namespace {
164 class CallFrameBase : public CallFrameInterface {
165  public:
CallFrameBase(DataTypeSlice ret_types)166   explicit CallFrameBase(DataTypeSlice ret_types)
167       : ret_types_(ret_types), retvals_(ret_types.size()) {}
168 
169   // Caller methods.
ConsumeRetvals(std::vector<Tensor> * retvals)170   Status ConsumeRetvals(std::vector<Tensor>* retvals) {
171     retvals->reserve(retvals_.size());
172     int i = 0;
173     for (auto&& val : retvals_) {
174       if (!val) {
175         return errors::Internal("No return value for index ", i, ".");
176       }
177       retvals->emplace_back(std::move(val.value()));
178       ++i;
179     }
180     return Status::OK();
181   }
182 
num_retvals() const183   size_t num_retvals() const override { return retvals_.size(); }
184 
185   // Callee methods.
SetRetval(int index,const Tensor & val)186   Status SetRetval(int index, const Tensor& val) override {
187     if (index < retvals_.size() && val.dtype() == ret_types_[index] &&
188         !retvals_[index]) {
189       retvals_[index] = val;
190       return Status::OK();
191     } else if (index >= retvals_.size()) {
192       return errors::InvalidArgument("Return value ", index,
193                                      " is out of range.");
194     } else if (val.dtype() != ret_types_[index]) {
195       return errors::InvalidArgument("Expected type ",
196                                      DataTypeString(ret_types_[index]),
197                                      " for return value ", index, " but got ",
198                                      DataTypeString(val.dtype()), ".");
199     } else {
200       return errors::Internal("Attempted to set return value ", index,
201                               " more than once.");
202     }
203   }
204 
205  private:
206   DataTypeSlice ret_types_;
207   std::vector<gtl::optional<Tensor>> retvals_;
208   TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
209 };
210 
211 class OwnedArgsCallFrame : public CallFrameBase {
212  public:
OwnedArgsCallFrame(std::vector<Tensor> && args,const std::vector<Tensor> * captured_inputs,DataTypeSlice ret_types)213   OwnedArgsCallFrame(std::vector<Tensor>&& args,
214                      const std::vector<Tensor>* captured_inputs,
215                      DataTypeSlice ret_types)
216       : CallFrameBase(ret_types),
217         args_(std::move(args)),
218         captured_inputs_(captured_inputs) {}
219 
num_args() const220   size_t num_args() const override {
221     return args_.size() + captured_inputs_->size();
222   }
223 
224   // Callee methods.
GetArg(int index,Tensor * val) const225   Status GetArg(int index, Tensor* val) const override {
226     if (index < args_.size() && args_[index].IsInitialized()) {
227       // TODO(mrry): Consider making `CallFrameInterface::GetArg` non-const in
228       // order to be able to `std::move(args_[index])` into `*val`.
229       *val = args_[index];
230       return Status::OK();
231     } else if (index < args_.size() + captured_inputs_->size()) {
232       *val = (*captured_inputs_)[index - args_.size()];
233       return Status::OK();
234     } else if (index >= args_.size() + captured_inputs_->size()) {
235       return errors::InvalidArgument("Argument ", index, " is out of range.");
236     } else {
237       return errors::Internal("Attempted to get argument ", index,
238                               " more than once.");
239     }
240   }
241 
242  private:
243   std::vector<Tensor> args_;
244   const std::vector<Tensor>* const captured_inputs_;  // Not owned.
245 };
246 
247 class BorrowedArgsCallFrame : public CallFrameBase {
248  public:
BorrowedArgsCallFrame(const std::vector<Tensor> & args,const std::vector<Tensor> * captured_inputs,DataTypeSlice ret_types)249   BorrowedArgsCallFrame(const std::vector<Tensor>& args,
250                         const std::vector<Tensor>* captured_inputs,
251                         DataTypeSlice ret_types)
252       : CallFrameBase(ret_types),
253         args_(args),
254         captured_inputs_(captured_inputs) {}
255 
num_args() const256   size_t num_args() const override {
257     return args_.size() + captured_inputs_->size();
258   }
259 
260   // Callee methods.
GetArg(int index,Tensor * val) const261   Status GetArg(int index, Tensor* val) const override {
262     if (index < args_.size() && args_[index].IsInitialized()) {
263       *val = args_[index];
264       return Status::OK();
265     } else if (index < args_.size() + captured_inputs_->size()) {
266       *val = (*captured_inputs_)[index - args_.size()];
267       return Status::OK();
268     } else if (index >= args_.size() + captured_inputs_->size()) {
269       return errors::InvalidArgument("Argument ", index, " is out of range.");
270     } else {
271       return errors::Internal("Attempted to get argument ", index,
272                               " more than once.");
273     }
274   }
275 
276  private:
277   const std::vector<Tensor>& args_;                   // Not owned.
278   const std::vector<Tensor>* const captured_inputs_;  // Not owned.
279 };
280 
281 }  // namespace
282 
InstantiatedCapturedFunction(FunctionLibraryRuntime * lib,FunctionLibraryRuntime::Handle f_handle,DataTypeVector ret_types,std::function<void (std::function<void ()>)> runner,CapturedFunction * captured_func)283 InstantiatedCapturedFunction::InstantiatedCapturedFunction(
284     FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
285     DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
286     CapturedFunction* captured_func)
287     : lib_(lib),
288       f_handle_(f_handle),
289       ret_types_(std::move(ret_types)),
290       captured_runner_(std::move(runner)),
291       captured_func_(captured_func) {}
292 
293 // NOTE: We don't release f_handle_ here and instead delegate the function
294 // handle releasing to the FunctionHandleCache. This is because in some cases
295 // (RepeatDatasetOp in particular), we want to keep the function state (e.g.
296 // random number generator) even after the Iterator is reset after going through
297 // one epoch.
~InstantiatedCapturedFunction()298 InstantiatedCapturedFunction::~InstantiatedCapturedFunction() {}
299 
Run(IteratorContext * ctx,std::vector<Tensor> && args,std::vector<Tensor> * rets) const300 Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
301                                          std::vector<Tensor>&& args,
302                                          std::vector<Tensor>* rets) const {
303   FunctionLibraryRuntime::Options f_opts;
304   f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
305   ScopedStepContainer step_container(
306       f_opts.step_id, [this](const string& name) {
307         lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
308       });
309   f_opts.step_container = &step_container;
310   f_opts.runner = ctx->runner();
311   if (lib_->device()->device_type() != DEVICE_CPU) {
312     f_opts.create_rendezvous = true;
313   }
314   // TODO(mrry): Add cancellation manager support to IteratorContext
315   // so that we can cancel running map functions. The local
316   // cancellation manager here is created so that we can run kernels
317   // (such as queue kernels) that depend on the non-nullness of
318   // `OpKernelContext::cancellation_manager()`, but additional effort
319   // will be required to plumb it through the `IteratorContext`.
320   CancellationManager c_mgr;
321   f_opts.cancellation_manager = &c_mgr;
322 
323   OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
324                            ret_types_);
325   Notification n;
326   Status s;
327   lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
328     s.Update(func_status);
329     n.Notify();
330   });
331   n.WaitForNotification();
332   TF_RETURN_IF_ERROR(s);
333   return frame.ConsumeRetvals(rets);
334 }
335 
RunWithBorrowedArgs(IteratorContext * ctx,const std::vector<Tensor> & args,std::vector<Tensor> * rets) const336 Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
337     IteratorContext* ctx, const std::vector<Tensor>& args,
338     std::vector<Tensor>* rets) const {
339   FunctionLibraryRuntime::Options f_opts;
340   f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
341   ScopedStepContainer step_container(
342       f_opts.step_id, [this](const string& name) {
343         lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
344       });
345   f_opts.step_container = &step_container;
346   f_opts.runner = ctx->runner();
347   if (lib_->device()->device_type() != DEVICE_CPU) {
348     f_opts.create_rendezvous = true;
349   }
350   // TODO(mrry): Add cancellation manager support to IteratorContext
351   // so that we can cancel running map functions. The local
352   // cancellation manager here is created so that we can run kernels
353   // (such as queue kernels) that depend on the non-nullness of
354   // `OpKernelContext::cancellation_manager()`, but additional effort
355   // will be required to plumb it through the `IteratorContext`.
356   CancellationManager c_mgr;
357   f_opts.cancellation_manager = &c_mgr;
358 
359   BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
360                               ret_types_);
361   Notification n;
362   Status s;
363 
364   lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
365     s.Update(func_status);
366     n.Notify();
367   });
368   n.WaitForNotification();
369   TF_RETURN_IF_ERROR(s);
370   return frame.ConsumeRetvals(rets);
371 }
372 
RunInstantiated(const std::vector<Tensor> & args,std::vector<Tensor> * rets)373 Status InstantiatedCapturedFunction::RunInstantiated(
374     const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
375   FunctionLibraryRuntime::Options f_opts;
376   f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
377   ScopedStepContainer step_container(
378       f_opts.step_id, [this](const string& name) {
379         lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
380       });
381   f_opts.step_container = &step_container;
382   f_opts.runner = &captured_runner_;
383   if (lib_->device()->device_type() != DEVICE_CPU) {
384     f_opts.create_rendezvous = true;
385   }
386   // TODO(mrry): Add cancellation manager support to IteratorContext
387   // so that we can cancel running map functions. The local
388   // cancellation manager here is created so that we can run kernels
389   // (such as queue kernels) that depend on the non-nullness of
390   // `OpKernelContext::cancellation_manager()`, but additional effort
391   // will be required to plumb it through the `IteratorContext`.
392   CancellationManager c_mgr;
393   f_opts.cancellation_manager = &c_mgr;
394 
395   BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
396                               ret_types_);
397   Notification n;
398   Status s;
399 
400   lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
401     s.Update(func_status);
402     n.Notify();
403   });
404   n.WaitForNotification();
405   TF_RETURN_IF_ERROR(s);
406   return frame.ConsumeRetvals(rets);
407 }
408 
RunAsync(IteratorContext * ctx,std::vector<Tensor> && args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done,const string & prefix) const409 void InstantiatedCapturedFunction::RunAsync(
410     IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
411     FunctionLibraryRuntime::DoneCallback done, const string& prefix) const {
412   // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
413   // be deleted before `done` is called. Take care not to capture `ctx` in any
414   // code that may execute asynchronously in this function.
415   OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
416       std::move(args), &captured_func_->captured_inputs(), ret_types_);
417 
418   FunctionLibraryRuntime::Options f_opts;
419   f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
420   ResourceMgr* resource_mgr = lib_->device()->resource_manager();
421   ScopedStepContainer* step_container = new ScopedStepContainer(
422       f_opts.step_id, [resource_mgr](const string& name) {
423         resource_mgr->Cleanup(name).IgnoreError();
424       });
425   f_opts.step_container = step_container;
426   f_opts.runner = ctx->runner();
427   if (lib_->device()->device_type() != DEVICE_CPU) {
428     f_opts.create_rendezvous = true;
429   }
430   // TODO(mrry): Add cancellation manager support to IteratorContext
431   // so that we can cancel running map functions. The local
432   // cancellation manager here is created so that we can run kernels
433   // (such as queue kernels) that depend on the non-nullness of
434   // `OpKernelContext::cancellation_manager()`, but additional effort
435   // will be required to plumb it through the `IteratorContext`.
436   CancellationManager* c_mgr = new CancellationManager();
437   f_opts.cancellation_manager = c_mgr;
438   std::shared_ptr<SimpleStepStatsCollector> stats_collector;
439   if (ctx->model() || ctx->stats_aggregator()) {
440     stats_collector = absl::make_unique<SimpleStepStatsCollector>();
441   }
442   f_opts.stats_collector = stats_collector.get();
443 
444   auto callback = std::bind(
445       [this, rets, step_container, c_mgr, frame](
446           const FunctionLibraryRuntime::DoneCallback& done,
447           const std::shared_ptr<model::Model>& model,
448           const std::shared_ptr<StatsAggregator>& stats_aggregator,
449           const string& prefix,
450           const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
451           // Begin unbound arguments.
452           Status s) {
453         delete step_container;
454         delete c_mgr;
455         if (s.ok()) {
456           s = frame->ConsumeRetvals(rets);
457         }
458         delete frame;
459         // TODO(shivaniagrawal): add the dataset name containing this function,
460         // make it dataset()->node_name() + captured_func_->func().name().
461         if (stats_aggregator) {
462           string prefix_with_func_name = strings::StrCat(
463               str_util::Split(prefix, "::", str_util::SkipEmpty()).back(),
464               "::", captured_func_->func().name());
465           stats_aggregator->AddToHistogram(
466               stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
467               {static_cast<float>(stats_collector->processing_time())});
468         }
469         if (model) {
470           model->AddProcessingTime(prefix, stats_collector->processing_time());
471           model->RecordStart(prefix, false /* stop_output */);
472         }
473         done(s);
474         if (model) {
475           model->RecordStop(prefix, false /* start_output */);
476         }
477       },
478       std::move(done), ctx->model(), ctx->stats_aggregator(), prefix,
479       std::move(stats_collector), std::placeholders::_1);
480 
481   lib_->Run(f_opts, f_handle_, frame, std::move(callback));
482 }
483 
CapturedFunction(const NameAttrList & func,std::vector<Tensor> captured_inputs,bool use_inter_op_parallelism)484 CapturedFunction::CapturedFunction(const NameAttrList& func,
485                                    std::vector<Tensor> captured_inputs,
486                                    bool use_inter_op_parallelism)
487     : func_(func),
488       captured_inputs_(std::move(captured_inputs)),
489       use_inter_op_parallelism_(use_inter_op_parallelism) {}
490 
491 }  // namespace data
492 }  // namespace tensorflow
493