/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/captured_function.h"

#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/notification.h"

namespace tensorflow {
namespace data {

namespace {

// Simplistic implementation of the `StepStatsCollectorInterface` that only
// cares about collecting the elapsed time spent executing a captured function.
class SimpleStepStatsCollector : public StepStatsCollectorInterface {
 public:
  void IncrementProcessingTime(int64 delta) {
    mutex_lock l(mu_);
    processing_time_ += delta;
  }

  NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
    return new SimpleNodeExecStats(this);
  }

  string ReportAllocsOnResourceExhausted(const string& err) override {
    return "";
  }

  int64 processing_time() {
    tf_shared_lock l(mu_);
    return processing_time_;
  }

 private:
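  // Per-node stats object that records the executor start and end times for a
  // node and, on `Done`, reports the elapsed time to the parent
  // `SimpleStepStatsCollector` before deleting itself.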
  class SimpleNodeExecStats : public NodeExecStatsInterface {
   public:
    explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
        : step_stats_collector_(step_stats_collector) {}

    void Done(const string& device) override {
      step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
                                                     start_time_ns_);
      delete this;
    }

    void RecordExecutorStarted() override {
      start_time_ns_ = Env::Default()->NowNanos();
    }

    void RecordComputeStarted() override {}

    void RecordComputeEnded() override {}

    void RecordExecutorEnded() override {
      end_time_ns_ = Env::Default()->NowNanos();
    }

    bool TrackAllocations() const override { return false; }

    void SetMemory(OpKernelContext* ctx) override {}

    void SetOutput(int slot, const Tensor* tensor) override {}

    void SetReferencedTensors(const TensorReferenceVector& tensors) override {}

    void SetScheduled(int64 nanos) override {}

   private:
    int64 start_time_ns_ = 0;
    int64 end_time_ns_ = 0;
    SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
  };

  mutex mu_;
  int64 processing_time_ GUARDED_BY(mu_) = 0;
};

}  // namespace

/* static */
Status CapturedFunction::Create(
    const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
    std::unique_ptr<CapturedFunction>* out_function) {
  return CapturedFunction::Create(func, ctx, argument_name,
                                  /*use_inter_op_parallelism=*/true,
                                  out_function);
}
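
// Example use (an illustrative sketch only; the input-list name
// "other_arguments" and the surrounding kernel are hypothetical):
//
//   NameAttrList func;  // typically parsed from a function attr such as "f"
//   std::unique_ptr<CapturedFunction> captured_func;
//   OP_REQUIRES_OK(ctx, CapturedFunction::Create(func, ctx, "other_arguments",
//                                                &captured_func));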

Status CapturedFunction::Create(
    const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
    bool use_inter_op_parallelism,
    std::unique_ptr<CapturedFunction>* out_function) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
  std::vector<Tensor> arguments(inputs.begin(), inputs.end());
  *out_function = absl::WrapUnique(new CapturedFunction(
      func, std::move(arguments), use_inter_op_parallelism));
  return Status::OK();
}

Status CapturedFunction::Create(
    const NameAttrList& func, OpKernelContext* ctx,
    std::vector<Tensor>&& captured_inputs, bool use_inter_op_parallelism,
    std::unique_ptr<CapturedFunction>* out_function) {
  *out_function = absl::WrapUnique(new CapturedFunction(
      func, std::move(captured_inputs), use_inter_op_parallelism));
  return Status::OK();
}

Status CapturedFunction::Instantiate(
    IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
                              instantiated_captured_function) {
  // The context's runtime will be used for all subsequent calls.
  FunctionLibraryRuntime* lib = ctx->lib();
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.overlay_lib = ctx->function_library().get();
  inst_opts.create_kernels_eagerly = true;
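  // Datasets that opt out of inter-op parallelism run their functions through
  // the single-threaded executor, which executes kernels sequentially rather
  // than scheduling them across the inter-op thread pool.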
  if (!use_inter_op_parallelism_) {
    inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
  }

  FunctionLibraryRuntime::Handle f_handle;
  TF_RETURN_IF_ERROR(ctx->function_handle_cache()->Instantiate(
      func_.name(), AttrSlice(&func_.attr()), inst_opts, &f_handle));
  const FunctionBody* fbody = lib->GetFunctionBody(f_handle);
  if (fbody == nullptr) {
    return errors::Internal("Failed to instantiate function body.");
  }

  DataTypeVector ret_types;
  for (const auto& ret_type : fbody->ret_types) {
    ret_types.push_back(ret_type);
  }

  *instantiated_captured_function =
      absl::WrapUnique<InstantiatedCapturedFunction>(
          new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types),
                                           *ctx->runner(), this));
  return Status::OK();
}
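
// Example use (illustrative sketch): an iterator implementation typically
// instantiates the captured function once during initialization and then
// invokes it per element. The member names below are hypothetical.
//
//   TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
//       ctx, &instantiated_captured_func_));
//   ...
//   std::vector<Tensor> rets;
//   TF_RETURN_IF_ERROR(
//       instantiated_captured_func_->Run(ctx, std::move(args), &rets));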

namespace {
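// Base class for the call frames used to pass arguments to, and collect
// results from, an instantiated captured function. This class handles the
// return-value bookkeeping; subclasses supply the argument tensors.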
class CallFrameBase : public CallFrameInterface {
 public:
  explicit CallFrameBase(DataTypeSlice ret_types)
      : ret_types_(ret_types), retvals_(ret_types.size()) {}

  // Caller methods.
  Status ConsumeRetvals(std::vector<Tensor>* retvals) {
    retvals->reserve(retvals_.size());
    int i = 0;
    for (auto&& val : retvals_) {
      if (!val) {
        return errors::Internal("No return value for index ", i, ".");
      }
      retvals->emplace_back(std::move(val.value()));
      ++i;
    }
    return Status::OK();
  }

  size_t num_retvals() const override { return retvals_.size(); }

  // Callee methods.
  Status SetRetval(int index, const Tensor& val) override {
    if (index < retvals_.size() && val.dtype() == ret_types_[index] &&
        !retvals_[index]) {
      retvals_[index] = val;
      return Status::OK();
    } else if (index >= retvals_.size()) {
      return errors::InvalidArgument("Return value ", index,
                                     " is out of range.");
    } else if (val.dtype() != ret_types_[index]) {
      return errors::InvalidArgument("Expected type ",
                                     DataTypeString(ret_types_[index]),
                                     " for return value ", index, " but got ",
                                     DataTypeString(val.dtype()), ".");
    } else {
      return errors::Internal("Attempted to set return value ", index,
                              " more than once.");
    }
  }

 private:
  DataTypeSlice ret_types_;
  std::vector<gtl::optional<Tensor>> retvals_;
  TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
};

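// Call frame that takes ownership of the per-call arguments, which are
// followed in the argument list by the (borrowed) captured inputs.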
class OwnedArgsCallFrame : public CallFrameBase {
 public:
  OwnedArgsCallFrame(std::vector<Tensor>&& args,
                     const std::vector<Tensor>* captured_inputs,
                     DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(std::move(args)),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, Tensor* val) const override {
    if (index < args_.size() && args_[index].IsInitialized()) {
      // TODO(mrry): Consider making `CallFrameInterface::GetArg` non-const in
      // order to be able to `std::move(args_[index])` into `*val`.
      *val = args_[index];
      return Status::OK();
    } else if (index < args_.size() + captured_inputs_->size()) {
      *val = (*captured_inputs_)[index - args_.size()];
      return Status::OK();
    } else if (index >= args_.size() + captured_inputs_->size()) {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    } else {
      return errors::Internal("Attempted to get argument ", index,
                              " more than once.");
    }
  }

 private:
  std::vector<Tensor> args_;
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

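// Call frame that borrows the per-call arguments from the caller, who must
// keep them alive until the function call completes.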
class BorrowedArgsCallFrame : public CallFrameBase {
 public:
  BorrowedArgsCallFrame(const std::vector<Tensor>& args,
                        const std::vector<Tensor>* captured_inputs,
                        DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(args),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, Tensor* val) const override {
    if (index < args_.size() && args_[index].IsInitialized()) {
      *val = args_[index];
      return Status::OK();
    } else if (index < args_.size() + captured_inputs_->size()) {
      *val = (*captured_inputs_)[index - args_.size()];
      return Status::OK();
    } else if (index >= args_.size() + captured_inputs_->size()) {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    } else {
      return errors::Internal("Attempted to get argument ", index,
                              " more than once.");
    }
  }

 private:
  const std::vector<Tensor>& args_;  // Not owned.
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

}  // namespace

InstantiatedCapturedFunction::InstantiatedCapturedFunction(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func)
    : lib_(lib),
      f_handle_(f_handle),
      ret_types_(std::move(ret_types)),
      captured_runner_(std::move(runner)),
      captured_func_(captured_func) {}

// NOTE: We don't release `f_handle_` here; instead, we delegate releasing it
// to the `FunctionHandleCache`. This is because in some cases (in particular,
// `RepeatDatasetOp`) we want to keep the function state (e.g. a random number
// generator) alive even after the iterator is reset at the end of an epoch.
InstantiatedCapturedFunction::~InstantiatedCapturedFunction() {}

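// Runs the instantiated function once, taking ownership of `args` and
// blocking the calling thread until the function completes.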
Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
                                         std::vector<Tensor>&& args,
                                         std::vector<Tensor>* rets) const {
  FunctionLibraryRuntime::Options f_opts;
  f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  if (lib_->device()->device_type() != DEVICE_CPU) {
    f_opts.create_rendezvous = true;
  }
  // TODO(mrry): Add cancellation manager support to IteratorContext
  // so that we can cancel running map functions. The local
  // cancellation manager here is created so that we can run kernels
  // (such as queue kernels) that depend on the non-nullness of
  // `OpKernelContext::cancellation_manager()`, but additional effort
  // will be required to plumb it through the `IteratorContext`.
  CancellationManager c_mgr;
  f_opts.cancellation_manager = &c_mgr;

  OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
                           ret_types_);
  Notification n;
  Status s;
  lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
    s.Update(func_status);
    n.Notify();
  });
  n.WaitForNotification();
  TF_RETURN_IF_ERROR(s);
  return frame.ConsumeRetvals(rets);
}

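// As `Run`, but borrows `args` from the caller, who must keep them alive
// until the call returns.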
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* rets) const {
  FunctionLibraryRuntime::Options f_opts;
  f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  if (lib_->device()->device_type() != DEVICE_CPU) {
    f_opts.create_rendezvous = true;
  }
  // TODO(mrry): Add cancellation manager support to IteratorContext
  // so that we can cancel running map functions. The local
  // cancellation manager here is created so that we can run kernels
  // (such as queue kernels) that depend on the non-nullness of
  // `OpKernelContext::cancellation_manager()`, but additional effort
  // will be required to plumb it through the `IteratorContext`.
  CancellationManager c_mgr;
  f_opts.cancellation_manager = &c_mgr;

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  Notification n;
  Status s;

  lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
    s.Update(func_status);
    n.Notify();
  });
  n.WaitForNotification();
  TF_RETURN_IF_ERROR(s);
  return frame.ConsumeRetvals(rets);
}

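// As `RunWithBorrowedArgs`, but uses the runner that was captured at
// instantiation time, so it can be invoked without an `IteratorContext`.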
Status InstantiatedCapturedFunction::RunInstantiated(
    const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
  FunctionLibraryRuntime::Options f_opts;
  f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = &captured_runner_;
  if (lib_->device()->device_type() != DEVICE_CPU) {
    f_opts.create_rendezvous = true;
  }
  // TODO(mrry): Add cancellation manager support to IteratorContext
  // so that we can cancel running map functions. The local
  // cancellation manager here is created so that we can run kernels
  // (such as queue kernels) that depend on the non-nullness of
  // `OpKernelContext::cancellation_manager()`, but additional effort
  // will be required to plumb it through the `IteratorContext`.
  CancellationManager c_mgr;
  f_opts.cancellation_manager = &c_mgr;

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  Notification n;
  Status s;

  lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
    s.Update(func_status);
    n.Notify();
  });
  n.WaitForNotification();
  TF_RETURN_IF_ERROR(s);
  return frame.ConsumeRetvals(rets);
}

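// Asynchronous counterpart to `Run`: takes ownership of `args` and invokes
// `done` with the final status once the function call completes.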
void InstantiatedCapturedFunction::RunAsync(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done, const string& prefix) const {
  // NOTE(mrry): This method does not transfer ownership of `ctx`, and `ctx`
  // may be deleted before `done` is called. Take care not to capture `ctx` in
  // any code that may execute asynchronously in this function.
  OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
      std::move(args), &captured_func_->captured_inputs(), ret_types_);

  FunctionLibraryRuntime::Options f_opts;
  f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
  ResourceMgr* resource_mgr = lib_->device()->resource_manager();
  ScopedStepContainer* step_container = new ScopedStepContainer(
      f_opts.step_id, [resource_mgr](const string& name) {
        resource_mgr->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = step_container;
  f_opts.runner = ctx->runner();
  if (lib_->device()->device_type() != DEVICE_CPU) {
    f_opts.create_rendezvous = true;
  }
  // TODO(mrry): Add cancellation manager support to IteratorContext
  // so that we can cancel running map functions. The local
  // cancellation manager here is created so that we can run kernels
  // (such as queue kernels) that depend on the non-nullness of
  // `OpKernelContext::cancellation_manager()`, but additional effort
  // will be required to plumb it through the `IteratorContext`.
  CancellationManager* c_mgr = new CancellationManager();
  f_opts.cancellation_manager = c_mgr;
  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (ctx->model() || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  f_opts.stats_collector = stats_collector.get();

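  // NOTE: `std::bind` is used below rather than a lambda because `done` and
  // `stats_collector` must be moved into the continuation, and C++11 lambdas
  // cannot capture by move.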
  auto callback = std::bind(
      [this, rets, step_container, c_mgr, frame](
          const FunctionLibraryRuntime::DoneCallback& done,
          const std::shared_ptr<model::Model>& model,
          const std::shared_ptr<StatsAggregator>& stats_aggregator,
          const string& prefix,
          const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
          // Begin unbound arguments.
          Status s) {
        delete step_container;
        delete c_mgr;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets);
        }
        delete frame;
        // TODO(shivaniagrawal): add the dataset name containing this function,
        // make it dataset()->node_name() + captured_func_->func().name().
        if (stats_aggregator) {
          string prefix_with_func_name = strings::StrCat(
              str_util::Split(prefix, "::", str_util::SkipEmpty()).back(),
              "::", captured_func_->func().name());
          stats_aggregator->AddToHistogram(
              stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
              {static_cast<float>(stats_collector->processing_time())});
        }
        if (model) {
          model->AddProcessingTime(prefix, stats_collector->processing_time());
          model->RecordStart(prefix, false /* stop_output */);
        }
        done(s);
        if (model) {
          model->RecordStop(prefix, false /* start_output */);
        }
      },
      std::move(done), ctx->model(), ctx->stats_aggregator(), prefix,
      std::move(stats_collector), std::placeholders::_1);

  lib_->Run(f_opts, f_handle_, frame, std::move(callback));
}

CapturedFunction::CapturedFunction(const NameAttrList& func,
                                   std::vector<Tensor> captured_inputs,
                                   bool use_inter_op_parallelism)
    : func_(func),
      captured_inputs_(std::move(captured_inputs)),
      use_inter_op_parallelism_(use_inter_op_parallelism) {}

}  // namespace data
}  // namespace tensorflow