/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#endif
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
typedef std::vector<Tensor> TensorVec;

namespace {

// Helper to instantiate function "func" in the library "lib".
Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}

template <typename To, typename From>  // use like this: down_cast<T*>(foo);
inline To down_cast(From* f) {         // so we only accept pointers
  static_assert(
      (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
      "target type not derived from source type");

  // We skip the assert and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
  // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
  assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
#endif  // !defined(__GNUC__) || defined(__GXX_RTTI)

  return static_cast<To>(f);
}
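
// down_cast is used below to recover a concrete Device* from the DeviceBase*
// held by OpKernelContext (see WhileOp::State::StartBody). In optimized
// builds it compiles down to a plain static_cast; the dynamic_cast-based
// check only runs in builds where asserts are enabled and RTTI is available.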

// Returns in "*v" the truth value of the single tensor in "t": for a scalar
// of a supported type, t != 0 (for DT_STRING, non-emptiness); for a
// non-scalar, whether the tensor has any elements.
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
  if (t.size() != 1) {
    return errors::InvalidArgument(
        "Expected a single scalar which can be converted to a boolean, got ",
        t.size(), " tensors.");
  }
  if (TensorShapeUtils::IsScalar(t[0].shape())) {
    switch (t[0].dtype()) {
#define CASE(T)                   \
  case DataTypeToEnum<T>::value:  \
    *v = t[0].scalar<T>()() != 0; \
    break;

      CASE(float);
      CASE(double);
      CASE(int32);
      CASE(uint8);
      CASE(int16);
      CASE(int8);
      CASE(int64);
#undef CASE
      case DT_BOOL:
        *v = t[0].scalar<bool>()();
        break;
      case DT_STRING:
        *v = !t[0].scalar<string>()().empty();
        break;
      default:
        return errors::InvalidArgument(DataTypeString(t[0].dtype()),
                                       " cannot be converted to a boolean");
    }
  } else {
    *v = t[0].NumElements() > 0;
  }
  return Status::OK();
}
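
// For illustration (the test::AsScalar helper is from tensor_testutil.h and
// is not used in this file), the semantics match Python-style truthiness:
//   ToBool({test::AsScalar<int32>(0)}, &v);            // v == false
//   ToBool({test::AsScalar<float>(2.5f)}, &v);         // v == true
//   ToBool({Tensor(DT_FLOAT, TensorShape({0}))}, &v);  // v == false (empty)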

// Sets the outputs of "ctx" to "rets". Validates the types of rets against
// the declared output types of "kernel".
Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
                  gtl::ArraySlice<Tensor> rets) {
  if (rets.size() != ctx->num_outputs()) {
    return errors::Internal("Expected to produce ", ctx->num_outputs(),
                            " tensors, but got ", rets.size());
  }
  for (int i = 0; i < rets.size(); ++i) {
    if (rets[i].dtype() != kernel->output_type(i)) {
      return errors::Internal("Expected ", i, "-th output to be of type ",
                              DataTypeString(kernel->output_type(i)),
                              " but got ", DataTypeString(rets[i].dtype()));
    }
    ctx->set_output(i, rets[i]);
  }
  return Status::OK();
}
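
// These mismatches are reported as errors::Internal rather than
// InvalidArgument because a function whose results disagree with the op's
// declared output signature indicates a malformed graph or a runtime bug,
// not bad user input.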

void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
                   bool always_collect_stats) {
  opts->step_id = ctx->step_id();
  opts->rendezvous = ctx->rendezvous();
  opts->cancellation_manager = ctx->cancellation_manager();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  opts->runner = ctx->runner();
  opts->step_container = ctx->step_container();
}
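
// Because the callee inherits the caller's step id, rendezvous, cancellation
// manager, runner and step container, the function invocation behaves as if
// it ran inline in the calling step: Send/Recv pairs match up across the
// call boundary and cancelling the step also cancels the nested call.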

class IfOp : public AsyncOpKernel {
 public:
  explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
  }

  ~IfOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which
    // have different associated `FunctionLibraryRuntime` objects and hence
    // different `FHandle` namespaces. So we must call Instantiate() to make
    // sure we get the correct function handles with respect to `lib`. Note
    // that the underlying `lib->Instantiate()` caches the created function
    // handles, so calling `Instantiate()` repeatedly on the same `lib` and
    // function is cheap.
    FHandle then_handle;
    FHandle else_handle;
    OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done);
    OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done);

    bool cond;
    // Use the _ASYNC variant so that `done` is still invoked if the
    // predicate cannot be converted to a boolean.
    OP_REQUIRES_OK_ASYNC(ctx, ToBool({ctx->input(0)}, &cond), done);
    (new State(this, ctx, cond, then_handle, else_handle, done))->Start();
  }

 private:
  NameAttrList then_func_;
  NameAttrList else_func_;

  class State {
   public:
    State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
          FHandle else_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_(cond),
          then_handle_(then_handle),
          else_handle_(else_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      FHandle handle = cond_ ? then_handle_ : else_handle_;
      rets_.clear();
      lib_->Run(
          // Evaluate one of the branches.
          opts_, handle, args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    IfOp* const kernel_;
    OpKernelContext* const ctx_;
    const bool cond_;
    FHandle then_handle_;
    FHandle else_handle_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };
};
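
// Each of the control-flow kernels in this file follows the same pattern:
// ComputeAsync() heap-allocates a State object holding the arguments, run
// options and result vector, then kicks it off with Start(). All further
// work happens in FunctionLibraryRuntime::Run() done-callbacks, and the
// State invokes the op's DoneCallback and deletes itself when the
// computation finishes, so ComputeAsync() never blocks a thread while the
// function executes.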

class CaseOp : public AsyncOpKernel {
 public:
  explicit CaseOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branch_funcs_));
  }

  ~CaseOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which
    // have different associated `FunctionLibraryRuntime` objects and hence
    // different `FHandle` namespaces. So we must call Instantiate() to make
    // sure we get the correct function handles with respect to `lib`. Note
    // that the underlying `lib->Instantiate()` caches the created function
    // handles, so calling `Instantiate()` repeatedly on the same `lib` and
    // function is cheap.
    std::vector<FHandle> branch_handles(branch_funcs_.size());
    for (int i = 0; i < branch_funcs_.size(); i++) {
      OP_REQUIRES_OK_ASYNC(
          ctx, Instantiate(lib, branch_funcs_[i], &branch_handles[i]), done);
    }

    const Tensor& branch_index = ctx->input(0);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()),
                      errors::InvalidArgument("branch_index must be scalar"),
                      done);
    int32 branch = branch_index.scalar<int32>()();
    (new State(this, ctx, branch, branch_handles, done))->Start();
  }

 private:
  std::vector<NameAttrList> branch_funcs_;

  class State {
   public:
    State(CaseOp* kernel, OpKernelContext* ctx, int branch,
          std::vector<FHandle> branch_handles, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          branch_(branch),
          branch_handles_(std::move(branch_handles)),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      int branch = branch_;
      // An out-of-range branch_index selects the last branch, which acts as
      // the default.
      if (branch < 0 || branch >= branch_handles_.size()) {
        branch = branch_handles_.size() - 1;
      }
      rets_.clear();
      lib_->Run(
          // Evaluate one of the branches.
          opts_, branch_handles_[branch], args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    CaseOp* const kernel_;
    OpKernelContext* const ctx_;
    const int branch_;
    std::vector<FHandle> branch_handles_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };
};
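
// In the GPU registrations below, HostMemory("cond") and
// HostMemory("branch_index") pin the predicate inputs to host memory so that
// ComputeAsync() can read them directly, without a device-to-host copy.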

// TODO(drpng): remove this.
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
                        IfOp);

REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

REGISTER_KERNEL_BUILDER(Name("Case").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("Case").Device(DEVICE_GPU).HostMemory("branch_index"), CaseOp);

REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

class WhileOp : public AsyncOpKernel {
 public:
  explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
  }

  ~WhileOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which
    // have different associated `FunctionLibraryRuntime` objects and hence
    // different `FHandle` namespaces. So we must call Instantiate() to make
    // sure we get the correct function handles with respect to `lib`. Note
    // that the underlying `lib->Instantiate()` caches the created function
    // handles, so calling `Instantiate()` repeatedly on the same `lib` and
    // function is cheap.
    FHandle cond_handle;
    FHandle body_handle;
    OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done);
    OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done);
    (new State(this, ctx, cond_handle, body_handle, done))->Start();
  }

 private:
  NameAttrList cond_func_;
  NameAttrList body_func_;

  class State {
   public:
    State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
          FHandle body_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_handle_(cond_handle),
          body_handle_(body_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
      for (int i = 0; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() { EvalCond(); }

   private:
    WhileOp* const kernel_;
    OpKernelContext* const ctx_;
    const FHandle cond_handle_;
    const FHandle body_handle_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;

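    // The loop is driven by a chain of asynchronous done-callbacks:
    // EvalCond() runs `cond`; its callback hands off to StartBody(), which
    // converts the predicate to a bool and either runs `body` (whose
    // callback swaps the body's results into args_ and re-enters EvalCond())
    // or calls Finish(). No thread blocks while the loop iterates, apart
    // from the device-to-host copy of the predicate below.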
    void EvalCond() {
      lib_->Run(
          // Evaluate the condition.
          opts_, cond_handle_, args_, &rets_,
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            StartBody();
          });
    }

    void StartBody() {
      Status s;
      if (rets_.size() != 1) {
        s = errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            rets_.size(), " tensors.");
        return Finish(s);
      }
      Tensor cond_t;
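      // The kernel treats DT_INT32/DT_INT64 results as already living in
      // host memory, so only predicates of other dtypes need an explicit
      // device-to-host copy before ToBool() can read them.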
#if GOOGLE_CUDA
      const DeviceBase::GpuDeviceInfo* gpu_device_info =
          ctx_->device()->tensorflow_gpu_device_info();
      const bool is_hostmem_dtype =
          rets_[0].dtype() == DT_INT32 || rets_[0].dtype() == DT_INT64;
      if (!is_hostmem_dtype && gpu_device_info &&
          (opts_.rets_alloc_attrs.empty() ||
           !opts_.rets_alloc_attrs[0].on_host())) {
        // Copy the ret value to host if it's allocated on device.
        Device* device = down_cast<Device*>(ctx_->device());
        DeviceContext* device_ctx = ctx_->op_device_context();
        cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
        Notification done_copy;
        device_ctx->CopyDeviceTensorToCPU(
            &rets_[0], /*tensor_name=*/"", device, &cond_t,
            [&done_copy, &s](const Status& status) {
              s = status;
              done_copy.Notify();
            });
        done_copy.WaitForNotification();
        if (!s.ok()) {
          return Finish(s);
        }
      } else {
        cond_t = rets_[0];
      }
#else
      cond_t = rets_[0];
#endif
      bool cond;
      s = ToBool({cond_t}, &cond);

      if (!s.ok()) {
        return Finish(s);
      }
      if (!cond) {
        return Finish(Status::OK());
      }
      rets_.clear();
      lib_->Run(
          // Evaluate the body.
          opts_, body_handle_, args_, &rets_,
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            if (args_.size() != rets_.size()) {
              return Finish(errors::InvalidArgument(
                  "While loop body returned ", rets_.size(),
                  " arguments. Expected: ", args_.size()));
            }
            args_.clear();
            using std::swap;
            swap(args_, rets_);
            EvalCond();
          });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, args_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};
// TODO(drpng): remove these.
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp);
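
// Unlike the If/Case registrations above, the GPU While registrations carry
// no HostMemory constraint: the predicate is a result of the `cond` function
// rather than an op input, so StartBody() copies it to the host explicitly
// when needed.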

Status GetScalar(OpKernelContext* ctx, int index, int32* value,
                 const char* label) {
  Tensor t = ctx->input(index);
  if (!TensorShapeUtils::IsScalar(t.shape())) {
    return errors::InvalidArgument(label, " must be a scalar, but has shape ",
                                   t.shape().DebugString());
  }
  *value = t.scalar<int32>()();
  return Status::OK();
}

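// ForOp's inputs are laid out as: input 0 = start, input 1 = limit,
// input 2 = delta, and inputs 3..N = the initial loop-carried values. The
// body function receives the current iteration counter followed by the
// loop-carried values, and returns the updated loop-carried values.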
class ForOp : public AsyncOpKernel {
 public:
  explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    const NameAttrList* func;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
    OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_));
  }

  ~ForOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    (new State(this, ctx, done))->Start();
  }

 private:
  FHandle body_handle_;

  class State {
   public:
    State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())),
          args_(1 + ctx_->num_inputs() - 3) {
      args_[0] = Tensor(DT_INT32, {});
      iter_ = &args_[0].scalar<int32>()();

      const int32 num_loop_inputs = ctx_->num_inputs() - 3;
      rets_.reserve(num_loop_inputs);
      for (int i = 0; i < num_loop_inputs; ++i) {
        rets_.push_back(ctx_->input(3 + i));
      }
    }

    ~State() {}

    void Start() {
      Status s = StartLoop();
      if (!s.ok()) Finish(s);
    }

   private:
    ForOp* const kernel_;
    OpKernelContext* const ctx_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;

    int32* iter_;  // points to args_[0].
    int32 limit_;
    int32 delta_;

    // If an error e is returned, the caller must call Finish(e).
    // If OK is returned, the async loop execution has been started.
    Status StartLoop() {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);

      TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));

      if ((delta_ > 0 && *iter_ <= limit_) ||
          (delta_ < 0 && *iter_ >= limit_) ||
          (delta_ == 0 && *iter_ == limit_)) {
        RunNext();
        return Status::OK();
      } else {
        return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_,
                                       " ", limit_, " ", delta_);
      }
    }
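
    // For example (illustrative values, not taken from this file): with
    // start=0, limit=3, delta=1, RunNext() invokes the body with i = 0, 1, 2
    // and then finishes, matching
    // `for (int32 i = start; i < limit; i += delta)`.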
    void RunNext() {
      bool done_loop;
      if (delta_ > 0) {
        done_loop = *iter_ >= limit_;
      } else {
        done_loop = *iter_ <= limit_;
      }
      if (done_loop) {
        Finish(Status::OK());
        return;
      }

      if (rets_.size() >= args_.size()) {
        Finish(errors::InvalidArgument(
            "For loop body returned ", rets_.size(),
            " arguments. Expected: ", args_.size() - 1));
        return;
      }
      for (int i = 0; i < rets_.size(); ++i) {
        args_[1 + i] = std::move(rets_[i]);
      }
      rets_.clear();
      lib_->Run(opts_, kernel_->body_handle_, args_, &rets_,
                [this](const Status& s) {
                  if (s.ok()) {
                    *iter_ += delta_;
                    RunNext();
                  } else {
                    Finish(s);
                  }
                });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, rets_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};

REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp);
REGISTER_KERNEL_BUILDER(Name("For")
                            .Device(DEVICE_GPU)
                            .HostMemory("start")
                            .HostMemory("limit")
                            .HostMemory("delta"),
                        ForOp);

// FakeParamOp allocates a tensor with a shape conforming to the expected
// output. This is necessary if the value will be stored in a while loop's
// TensorList. The output is otherwise not expected to be consumed by
// anything else.
class FakeParamOp : public OpKernel {
 public:
  explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
    DataType dtype;
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));

    // Set shape to the specified shape, setting unknown dimensions to zero.
    // If the specified shape is entirely unknown, leave it as an empty shape.
    TensorShape shape;
    PartialTensorShape partial_shape;
    OP_REQUIRES_OK(context, context->GetAttr("shape", &partial_shape));
    if (!partial_shape.unknown_rank()) {
      for (int64 d : partial_shape.dim_sizes()) {
        shape.AddDim(d == -1 ? 0 : d);
      }
    }

    // Create a persistent tensor that we can repeatedly return to save memory.
    // TODO(b/119612758): add optimization to prevent sending this across
    // devices on each Compute() call.
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                dtype, shape, &value_handle_, nullptr));
  }

  void Compute(OpKernelContext* context) override {
    context->set_output(0, *value_handle_.AccessTensor(context));
  }

 private:
  PersistentTensor value_handle_;
};

REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);

}  // namespace
}  // namespace tensorflow