/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/types.h"
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#endif
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
typedef std::vector<Tensor> TensorVec;

namespace {

// Helper to instantiate function "func" in the library "lib".
Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}

Status Instantiate(OpKernelContext* ctx, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  FunctionLibraryRuntime::InstantiateOptions opts;
  opts.executor_type = ctx->executor_type();
  return ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), opts, handle);
}

// If "t" is a scalar of a supported type, returns t != 0 in "*v".
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
  if (t.size() != 1) {
    return errors::InvalidArgument(
        "Expected a single scalar which can be converted to a boolean, got ",
        t.size(), " tensors.");
  }
  if (TensorShapeUtils::IsScalar(t[0].shape())) {
    switch (t[0].dtype()) {
#define CASE(T)                   \
  case DataTypeToEnum<T>::value:  \
    *v = t[0].scalar<T>()() != 0; \
    break;

      CASE(float);
      CASE(double);
      CASE(int32);
      CASE(uint8);
      CASE(int16);
      CASE(int8);
      CASE(int64);
#undef CASE
      case DT_BOOL:
        *v = t[0].scalar<bool>()();
        break;
      case DT_STRING:
        *v = !t[0].scalar<tstring>()().empty();
        break;
      default:
        return errors::InvalidArgument(DataTypeString(t[0].dtype()),
                                       " cannot be converted to a boolean");
    }
  } else {
    *v = t[0].NumElements() > 0;
  }
  return Status::OK();
}

// Sets "rets" to be the output of "ctx". Validates rets' types based
// on "kernel".
Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
                  gtl::ArraySlice<Tensor> rets) {
  if (rets.size() != ctx->num_outputs()) {
    return errors::Internal("Expect to produce ", ctx->num_outputs(),
                            " tensors, but only get ", rets.size());
  }
  for (int i = 0; i < rets.size(); ++i) {
    if (rets[i].dtype() != kernel->output_type(i)) {
      return errors::Internal("Expect ", i, "-th output is of type ",
                              DataTypeString(kernel->output_type(i)),
                              " but get ", DataTypeString(rets[i].dtype()));
    }
    ctx->set_output(i, rets[i]);
  }
  return Status::OK();
}

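// Copies per-step execution state (rendezvous, cancellation manager,
// collective executor, runner, step container, and optionally the stats
// collector) from the calling kernel's context into the function runtime
// options used to run the branch/cond/body functions.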
void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
                   bool always_collect_stats) {
  opts->rendezvous = ctx->rendezvous();
  opts->cancellation_manager = ctx->cancellation_manager();
  opts->collective_executor = ctx->collective_executor();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  opts->runner = ctx->runner();
  opts->run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts->step_container = ctx->step_container();
}

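// AsyncOpKernel implementing If/StatelessIf: converts input 0 (the predicate)
// to a boolean and runs either the `then_branch` or the `else_branch` function
// on the remaining inputs.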
class IfOp : public AsyncOpKernel {
 public:
  explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
  }

  ~IfOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FHandle then_handle;
    FHandle else_handle;
    OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle),
                         done);
    bool cond;
    OP_REQUIRES_OK_ASYNC(ctx, ToBool({ctx->input(0)}, &cond), done);
    (new State(this, ctx, cond, then_handle, else_handle, done))->Start();
  }

 private:
  NameAttrList then_func_;
  NameAttrList else_func_;

  mutex mu_;
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ ABSL_GUARDED_BY(mu_);

  class State {
   public:
    State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
          FHandle else_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_(cond),
          then_handle_(then_handle),
          else_handle_(else_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      FHandle handle = cond_ ? then_handle_ : else_handle_;
      rets_.clear();
      profiler::TraceMe trace_me("IfOp");
      lib_->Run(
          // Evaluate one of the branches.
          opts_, handle, args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    IfOp* const kernel_;
    OpKernelContext* const ctx_;
    const bool cond_;
    FHandle then_handle_;
    FHandle else_handle_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };

  Status GetHandles(OpKernelContext* ctx, FHandle* then_handle,
                    FHandle* else_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    auto lib = ctx->function_library();
    if (lib == nullptr) return errors::Internal("No function library");
    *then_handle = kInvalidHandle;
    *else_handle = kInvalidHandle;
    {
      tf_shared_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *then_handle = iter->second.first;
        *else_handle = iter->second.second;
      }
    }
    if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) {
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *then_handle = iter->second.first;
        *else_handle = iter->second.second;
      } else {
        TF_RETURN_IF_ERROR(Instantiate(ctx, then_func_, then_handle));
        TF_RETURN_IF_ERROR(Instantiate(ctx, else_func_, else_handle));
        handles_[lib] = {*then_handle, *else_handle};
      }
    }
    return Status::OK();
  }
};

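// AsyncOpKernel implementing Case/StatelessCase: runs branches[branch_index]
// on the remaining inputs. An out-of-range index dispatches to the last
// branch, which acts as the default.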
class CaseOp : public AsyncOpKernel {
 public:
  explicit CaseOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branch_funcs_));
  }

  ~CaseOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which have
    // different associated `FunctionLibraryRuntime` objects and hence different
    // `FHandle` namespaces. So we must call Instantiate() to make sure we get
    // the correct function handles with respect to `lib`. Note the underlying
    // `lib->Instantiate()` caches the created function handles, so calling
    // `Instantiate()` repeatedly on the same `lib` and function is cheap.
    std::vector<FHandle> branch_handles(branch_funcs_.size());
    for (int i = 0; i < branch_funcs_.size(); i++) {
      OP_REQUIRES_OK_ASYNC(
          ctx, Instantiate(lib, branch_funcs_[i], &branch_handles[i]), done);
    }

    const Tensor& branch_index = ctx->input(0);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()),
                      errors::InvalidArgument("branch_index must be scalar"),
                      done);
    int32 branch = branch_index.scalar<int32>()();
    (new State(this, ctx, branch, branch_handles, done))->Start();
  }

 private:
  std::vector<NameAttrList> branch_funcs_;

  class State {
   public:
    State(CaseOp* kernel, OpKernelContext* ctx, int branch,
          std::vector<FHandle> branch_handles, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          branch_(branch),
          branch_handles_(branch_handles),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      int branch = branch_;
      // The last branch is the default branch.
      if (branch < 0 || branch >= branch_handles_.size()) {
        branch = branch_handles_.size() - 1;
      }
      rets_.clear();
      profiler::TraceMe trace_me("CaseOp");
      lib_->Run(
          // Evaluate one of the branches.
          opts_, branch_handles_[branch], args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    CaseOp* const kernel_;
    OpKernelContext* const ctx_;
    const int branch_;
    std::vector<FHandle> branch_handles_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };
};

// TODO(drpng): remove this.
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
                        IfOp);

REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

REGISTER_KERNEL_BUILDER(Name("Case").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("Case").Device(DEVICE_GPU).HostMemory("branch_index"), CaseOp);
REGISTER_KERNEL_BUILDER(Name("StatelessCase").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessCase").Device(DEVICE_GPU).HostMemory("branch_index"),
    CaseOp);

REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

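// AsyncOpKernel implementing While/StatelessWhile: repeatedly runs `body` on
// the loop variables for as long as `cond` evaluates to true, then outputs the
// final loop variables.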
class WhileOp : public AsyncOpKernel {
 public:
  explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
  }

  ~WhileOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    if (ctx->run_all_kernels_inline()) {
      // Use the non-callback-based implementation when kernels (and function
      // callbacks) execute inline to avoid stack overflow.
      OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done);
      // The synchronous computation completed on this thread, so signal
      // completion to the executor.
      done();
    } else {
      FHandle cond_handle;
      FHandle body_handle;
      OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
                           done);
      (new State(this, ctx, cond_handle, body_handle, done))->Start();
    }
  }

  void Compute(OpKernelContext* ctx) override {
    // Use the non-callback-based implementation when the synchronous Compute()
    // method is invoked, because the caller is explicitly donating a thread.
    Status s = DoComputeSync(ctx);
    // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is
    // still an AsyncOpKernel, and there is a run-time check to avoid calling
    // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in
    // the event of an error).
    if (TF_PREDICT_FALSE(!s.ok())) {
      ctx->SetStatus(s);
    }
  }

 private:
  NameAttrList cond_func_;
  NameAttrList body_func_;

  mutex mu_;
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ ABSL_GUARDED_BY(mu_);

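  // Converts the single tensor returned by the cond function into a boolean,
  // copying it to host memory first if it was produced in device memory.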
  static Status CondResultToBool(OpKernelContext* ctx,
                                 const FunctionLibraryRuntime::Options& opts,
                                 const Tensor& cond_t, bool* out_result) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    const DeviceBase::GpuDeviceInfo* gpu_device_info =
        ctx->device()->tensorflow_gpu_device_info();
    const bool is_hostmem_dtype =
        cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64;
    if (!is_hostmem_dtype && gpu_device_info &&
        (opts.rets_alloc_attrs.empty() ||
         !opts.rets_alloc_attrs[0].on_host())) {
      // Copy the ret value to host if it's allocated on device.
      Device* device = down_cast<Device*>(ctx->device());
      DeviceContext* device_ctx = ctx->op_device_context();
      Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape());
      TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync(
          &cond_t, /*tensor_name=*/"", device, &host_cond_t));
      return ToBool({host_cond_t}, out_result);
    }
#endif
    return ToBool({cond_t}, out_result);
  }

  // The initial loop variable args are the inputs to the kernel.
  //
  // We attempt to forward the input so that it can be consumed inside the
  // body function (and participate in buffer forwarding, etc.).
  static void GetArgsFromContext(OpKernelContext* ctx,
                                 std::vector<Tensor>* out_args,
                                 DataTypeVector* out_var_types) {
    const int num_loop_vars = ctx->num_inputs();
    out_args->reserve(num_loop_vars);
    out_var_types->resize(num_loop_vars);
    for (int i = 0; i < num_loop_vars; ++i) {
      const Tensor& input = ctx->input(i);
      (*out_var_types)[i] = input.dtype();
      std::unique_ptr<Tensor> maybe_forwarded_input = ctx->forward_input(
          i, /* output_index= */ OpKernelContext::Params::kNoReservation,
          input.dtype(), input.shape(), ctx->input_memory_type(i),
          ctx->input_alloc_attr(i));
      if (maybe_forwarded_input) {
        out_args->push_back(std::move(*maybe_forwarded_input));
      } else {
        out_args->push_back(input);
      }
    }
  }

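  // Call frame used to invoke the body function: exposes the current loop
  // variables as arguments (allowing the body to consume them in place) and
  // collects and type-checks the values the body returns.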
  class BodyFuncCallFrame : public CallFrameInterface {
   public:
    BodyFuncCallFrame(std::vector<Tensor>* args, std::vector<Tensor>* retvals,
                      DataTypeSlice ret_types)
        : args_(args), retvals_(retvals), ret_types_(ret_types) {}

    size_t num_args() const override { return args_->size(); }
    size_t num_retvals() const override { return retvals_->size(); }

    Status GetArg(int index, const Tensor** val) override {
      if (index < args_->size()) {
        *val = &(*args_)[index];
        return Status::OK();
      } else {
        return errors::InvalidArgument("Argument ", index, " is out of range.");
      }
    }

    void ConsumeArg(int index, Tensor* val) override {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, args_->size());
      *val = std::move((*args_)[index]);
    }
    bool CanConsumeArg(int index) const override {
      return index >= 0 && index < args_->size();
    }

    Status SetRetval(int index, const Tensor& val) override {
      if (TF_PREDICT_FALSE(index < 0)) {
        return errors::InvalidArgument(
            "Expected non-negative return value index, but got: ", index, ".");
      } else if (TF_PREDICT_FALSE(index >= retvals_->size())) {
        return errors::InvalidArgument("While loop body returned ", index + 1,
                                       " arguments. Expected: ", num_retvals(),
                                       ".");
      } else if (TF_PREDICT_FALSE(val.dtype() != ret_types_[index])) {
        return errors::InvalidArgument("Expected type ",
                                       DataTypeString(ret_types_[index]),
                                       " for return value ", index, " but got ",
                                       DataTypeString(val.dtype()), ".");
      }
      (*retvals_)[index] = val;
      return Status::OK();
    }

   private:
    std::vector<Tensor>* const args_;     // Not owned.
    std::vector<Tensor>* const retvals_;  // Not owned.
    DataTypeSlice ret_types_;

    TF_DISALLOW_COPY_AND_ASSIGN(BodyFuncCallFrame);
  };

  class State {
   public:
    State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
          FHandle body_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_handle_(cond_handle),
          body_handle_(body_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
      GetArgsFromContext(ctx, &args_, &loop_var_types_);
      body_frame_ =
          absl::make_unique<BodyFuncCallFrame>(&args_, &rets_, loop_var_types_);
    }

    ~State() {}

    void Start() { EvalCond(); }

   private:
    WhileOp* const kernel_;
    OpKernelContext* const ctx_;
    const FHandle cond_handle_;
    const FHandle body_handle_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
    DataTypeVector loop_var_types_;
    std::unique_ptr<BodyFuncCallFrame> body_frame_;

    void EvalCond() {
      profiler::TraceMe trace_me("WhileOp-EvalCond");
      lib_->Run(
          // Evaluate the condition.
          opts_, cond_handle_, args_, &rets_,
          // Done cb.
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            StartBody();
          });
    }

    void StartBody() {
      Status s;
      if (rets_.size() != 1) {
        s = errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            rets_.size(), " tensors.");
        return Finish(s);
      }

      if (!s.ok()) {
        return Finish(s);
      }
      bool cond;
      s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
      if (!s.ok()) {
        return Finish(s);
      }

      if (!cond) {
        return Finish(Status::OK());
      }
      rets_.clear();
      rets_.resize(args_.size());
      profiler::TraceMe trace_me("WhileOp-StartBody");
      lib_->Run(
          // Evaluate the body.
          opts_, body_handle_, body_frame_.get(),
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            if (args_.size() != rets_.size()) {
              return Finish(errors::InvalidArgument(
                  "While loop body returned ", rets_.size(),
                  " arguments. Expected: ", args_.size()));
            }
            args_.clear();
            using std::swap;
            swap(args_, rets_);
            EvalCond();
          });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, args_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };

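  // Runs the loop synchronously on the calling thread. Used when kernels
  // execute inline or when the synchronous Compute() entry point is invoked.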
  Status DoComputeSync(OpKernelContext* ctx) {
    FHandle cond_handle;
    FHandle body_handle;
    TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle));
    auto lib = ctx->function_library();
    FunctionLibraryRuntime::Options opts;
    SetRunOptions(ctx, &opts, false /* always_collect_stats */);

    // Pre-allocate argument and return value vectors for the cond and body
    // functions.
    std::vector<Tensor> args;
    const int num_loop_vars = ctx->num_inputs();
    DataTypeVector loop_var_types(num_loop_vars);
    GetArgsFromContext(ctx, &args, &loop_var_types);
    std::vector<Tensor> cond_rets;
    cond_rets.reserve(1);
    std::vector<Tensor> body_rets;
    body_rets.reserve(num_loop_vars);

    // Implement the logic of the while loop as a single C++ do-while loop that
    // executes the cond and body functions synchronously.
    do {
      // Evaluate the cond function on the current loop variables.
      {
        profiler::TraceMe trace_me("WhileOp-EvalCond");
        TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
      }
      if (cond_rets.size() != 1) {
        return errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            cond_rets.size(), " tensors.");
      }

      // If the cond function evaluates to false, we are done: output the
      // current loop variables.
      bool cond_result;
      TF_RETURN_IF_ERROR(
          CondResultToBool(ctx, opts, cond_rets[0], &cond_result));
      if (!cond_result) {
        return SetOutputs(this, ctx, args);
      }

      // Evaluate the body function on the current loop variables, to get an
      // updated vector of loop variables.
      {
        profiler::TraceMe trace_me("WhileOp-StartBody");
        body_rets.resize(num_loop_vars);
        BodyFuncCallFrame call_frame(&args, &body_rets, loop_var_types);
        TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, &call_frame));
      }
      std::swap(body_rets, args);
      body_rets.clear();
    } while (true);
  }

  Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
                    FHandle* body_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    auto lib = ctx->function_library();
    if (lib == nullptr) return errors::Internal("No function library");
    *cond_handle = kInvalidHandle;
    *body_handle = kInvalidHandle;
    {
      tf_shared_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *cond_handle = iter->second.first;
        *body_handle = iter->second.second;
      }
    }
    if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) {
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *cond_handle = iter->second.first;
        *body_handle = iter->second.second;
      } else {
        TF_RETURN_IF_ERROR(Instantiate(ctx, cond_func_, cond_handle));
        TF_RETURN_IF_ERROR(Instantiate(ctx, body_func_, body_handle));
        handles_[lib] = {*cond_handle, *body_handle};
      }
    }
    return Status::OK();
  }
};
// TODO(drpng): remove these.
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp);

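// Kernel for the ToBool op: converts its input to a boolean scalar output
// using the same conversion rules as the If/While predicate (ToBool above).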
class ToBoolOp : public OpKernel {
 public:
  explicit ToBoolOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    bool b;
    OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &b));
    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    out->scalar<bool>()() = b;
  }
};

REGISTER_KERNEL_BUILDER(Name("ToBool").Device(DEVICE_CPU), ToBoolOp);

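// Reads input "index" from "ctx" as an int32 scalar into "*value"; "label"
// names the input in error messages.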
Status GetScalar(OpKernelContext* ctx, int index, int32* value,
                 const char* label) {
  Tensor t = ctx->input(index);
  if (!TensorShapeUtils::IsScalar(t.shape())) {
    return errors::InvalidArgument(label, " must be a scalar, but ",
                                   t.shape().DebugString());
  }
  *value = t.scalar<int32>()();
  return Status::OK();
}

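// AsyncOpKernel implementing For: runs `body(i, loop_vars...)` for each i in
// range(start, limit, delta), threading the loop variables returned by the
// body through each iteration.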
class ForOp : public AsyncOpKernel {
 public:
  explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    const NameAttrList* func;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
    OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_));
  }

  ~ForOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    (new State(this, ctx, done))->Start();
  }

 private:
  FHandle body_handle_;

  class State {
   public:
    State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())),
          args_(1 + ctx_->num_inputs() - 3) {
      args_[0] = Tensor(DT_INT32, {});
      iter_ = &args_[0].scalar<int32>()();

      const int32 num_loop_inputs = ctx_->num_inputs() - 3;
      rets_.reserve(num_loop_inputs);
      for (int i = 0; i < num_loop_inputs; ++i) {
        rets_.push_back(ctx_->input(3 + i));
      }
    }

    ~State() {}

    void Start() {
      Status s = StartLoop();
      if (!s.ok()) Finish(s);
    }

   private:
    ForOp* const kernel_;
    OpKernelContext* const ctx_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;

    int32* iter_;  // points to args_[0].
    int32 limit_;
    int32 delta_;

    // If an error e is returned, caller must call Finish(e).
    // If OK is returned, the async loop execution has been started.
    Status StartLoop() {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);

      TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));

      if ((delta_ > 0 && *iter_ <= limit_) ||
          (delta_ < 0 && *iter_ >= limit_) ||
          (delta_ == 0 && *iter_ == limit_)) {
        RunNext();
        return Status::OK();
      } else {
        return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_,
                                       " ", limit_, " ", delta_);
      }
    }

    void RunNext() {
      bool done_loop;
      if (delta_ > 0) {
        done_loop = *iter_ >= limit_;
      } else {
        done_loop = *iter_ <= limit_;
      }
      if (done_loop) {
        Finish(Status::OK());
        return;
      }

      if (rets_.size() >= args_.size()) {
        Finish(errors::InvalidArgument(
            "For loop body returned ", rets_.size(),
            " arguments. Expected: ", args_.size() - 1));
        return;
      }
      for (int i = 0; i < rets_.size(); ++i) {
        args_[1 + i] = std::move(rets_[i]);
      }
      rets_.clear();
      profiler::TraceMe trace_me("ForOp");
      lib_->Run(opts_, kernel_->body_handle_, args_, &rets_,
                [this](const Status& s) {
                  if (s.ok()) {
                    *iter_ += delta_;
                    RunNext();
                  } else {
                    Finish(s);
                  }
                });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, rets_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};

REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp);
REGISTER_KERNEL_BUILDER(Name("For")
                            .Device(DEVICE_GPU)
                            .HostMemory("start")
                            .HostMemory("limit")
                            .HostMemory("delta"),
                        ForOp);

// FakeParamOp allocates a tensor with a shape conforming to the expected
// output. This is necessary if the value will be stored in a while_loop's
// TensorList. The output is otherwise not expected to be consumed by anything
// else.
class FakeParamOp : public OpKernel {
 public:
  explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
    DataType dtype;
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));

    // Set shape to the specified shape, setting unknown dimensions to empty.
    // If the specified shape is unknown, leave as an empty shape.
    TensorShape shape;
    PartialTensorShape partial_shape;
    OP_REQUIRES_OK(context, context->GetAttr("shape", &partial_shape));
    if (!partial_shape.unknown_rank()) {
      for (int64 d : partial_shape.dim_sizes()) {
        shape.AddDim(d == -1 ? 0 : d);
      }
    }

    // Create a persistent tensor that we can repeatedly return to save memory.
    // TODO(b/119612758): add optimization to prevent sending this across
    // devices on each Compute() call.
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                dtype, shape, &value_handle_, nullptr));
  }

  void Compute(OpKernelContext* context) override {
    context->set_output(0, *value_handle_.AccessTensor(context));
  }

 private:
  PersistentTensor value_handle_;
};

REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);

// DeviceIndexOp returns the current device index.
class DeviceIndexOp : public OpKernel {
 public:
  explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* device_name_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &device_name_t));
    DeviceNameUtils::ParsedName parsed_name;
    int index = device_names_.size();
    if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
        parsed_name.has_type) {
      auto it = absl::c_find(device_names_, parsed_name.type);
      if (it != device_names_.end()) {
        index = it - device_names_.begin();
      }
    }
    device_name_t->scalar<int32>()() = index;
  }

 private:
  PersistentTensor value_handle_;
  std::vector<string> device_names_;
};

REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
REGISTER_KERNEL_BUILDER(
    Name("DeviceIndex").Device(DEVICE_GPU).HostMemory("index"), DeviceIndexOp);

}  // namespace
}  // namespace tensorflow