/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#endif
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
typedef std::vector<Tensor> TensorVec;

namespace {

// Helper to instantiate function "func" in the library "lib".
Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}

template <typename To, typename From>  // use like this: down_cast<T*>(foo);
inline To down_cast(From* f) {  // so we only accept pointers
  static_assert(
      (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
      "target type not derived from source type");

  // We skip the assert and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
  // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
  assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
#endif  // !defined(__GNUC__) || defined(__GXX_RTTI)

  return static_cast<To>(f);
}

// If "t" is a scalar of a supported type, sets *v to t != 0 (or, for
// DT_STRING, to whether the string is non-empty). If "t" is not a scalar,
// sets *v to whether it has any elements.
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
  if (t.size() != 1) {
    return errors::InvalidArgument(
        "Expected a single scalar which can be converted to a boolean, got ",
        t.size(), " tensors.");
  }
  if (TensorShapeUtils::IsScalar(t[0].shape())) {
    switch (t[0].dtype()) {
#define CASE(T)                   \
  case DataTypeToEnum<T>::value:  \
    *v = t[0].scalar<T>()() != 0; \
    break;

      CASE(float);
      CASE(double);
      CASE(int32);
      CASE(uint8);
      CASE(int16);
      CASE(int8);
      CASE(int64);
#undef CASE
      case DT_BOOL:
        *v = t[0].scalar<bool>()();
        break;
      case DT_STRING:
        *v = !t[0].scalar<string>()().empty();
        break;
      default:
        return errors::InvalidArgument(DataTypeString(t[0].dtype()),
                                       " cannot be converted to a boolean");
    }
  } else {
    *v = t[0].NumElements() > 0;
  }
  return Status::OK();
}

// Sets "rets" to be the output of "ctx". Validates rets' types based
// on "kernel".
Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
                  gtl::ArraySlice<Tensor> rets) {
  if (rets.size() != ctx->num_outputs()) {
    return errors::Internal("Expected to produce ", ctx->num_outputs(),
                            " tensors, but got ", rets.size());
  }
  for (int i = 0; i < rets.size(); ++i) {
    if (rets[i].dtype() != kernel->output_type(i)) {
      return errors::Internal("Expected output ", i, " to be of type ",
                              DataTypeString(kernel->output_type(i)),
                              " but got ", DataTypeString(rets[i].dtype()));
    }
    ctx->set_output(i, rets[i]);
  }
  return Status::OK();
}

void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
                   bool always_collect_stats) {
  opts->step_id = ctx->step_id();
  opts->rendezvous = ctx->rendezvous();
  opts->cancellation_manager = ctx->cancellation_manager();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  opts->runner = ctx->runner();
  opts->step_container = ctx->step_container();
}

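// IfOp evaluates either `then_branch` or `else_branch`, depending on the
// scalar `cond` value in input 0. The remaining inputs are forwarded as the
// arguments of the chosen branch function, and that branch's outputs become
// the op's outputs.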
class IfOp : public AsyncOpKernel {
 public:
  explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
  }

  ~IfOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which
    // have different associated `FunctionLibraryRuntime` objects and hence
    // different `FHandle` namespaces. So we must call Instantiate() to make
    // sure we get the correct function handles with respect to `lib`. Note
    // the underlying `lib->Instantiate()` caches the created function
    // handles, so calling `Instantiate()` repeatedly on the same `lib` and
    // function is cheap.
    FHandle then_handle;
    FHandle else_handle;
    OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done);
    OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done);

    bool cond;
    OP_REQUIRES_OK_ASYNC(ctx, ToBool({ctx->input(0)}, &cond), done);
    (new State(this, ctx, cond, then_handle, else_handle, done))->Start();
  }

 private:
  NameAttrList then_func_;
  NameAttrList else_func_;

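  // Per-invocation state. Runs the selected branch through the function
  // library runtime and deletes itself from the done callback once the
  // branch's outputs have been forwarded to the op's context.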
  class State {
   public:
    State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
          FHandle else_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_(cond),
          then_handle_(then_handle),
          else_handle_(else_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      FHandle handle = cond_ ? then_handle_ : else_handle_;
      rets_.clear();
      lib_->Run(
          // Evaluate one of the branches.
          opts_, handle, args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    IfOp* const kernel_;
    OpKernelContext* const ctx_;
    const bool cond_;
    FHandle then_handle_;
    FHandle else_handle_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };
};

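// CaseOp evaluates branches[branch_index], where branch_index is the scalar
// int32 value in input 0. An out-of-range index selects the last branch as
// the default. The remaining inputs are forwarded as the branch's arguments.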
class CaseOp : public AsyncOpKernel {
 public:
  explicit CaseOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branch_funcs_));
  }

  ~CaseOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which
    // have different associated `FunctionLibraryRuntime` objects and hence
    // different `FHandle` namespaces. So we must call Instantiate() to make
    // sure we get the correct function handles with respect to `lib`. Note
    // the underlying `lib->Instantiate()` caches the created function
    // handles, so calling `Instantiate()` repeatedly on the same `lib` and
    // function is cheap.
    std::vector<FHandle> branch_handles(branch_funcs_.size());
    for (int i = 0; i < branch_funcs_.size(); i++) {
      OP_REQUIRES_OK_ASYNC(
          ctx, Instantiate(lib, branch_funcs_[i], &branch_handles[i]), done);
    }

    const Tensor& branch_index = ctx->input(0);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()),
                      errors::InvalidArgument("branch_index must be scalar"),
                      done);
    int32 branch = branch_index.scalar<int32>()();
    (new State(this, ctx, branch, branch_handles, done))->Start();
  }

 private:
  std::vector<NameAttrList> branch_funcs_;

  class State {
   public:
    State(CaseOp* kernel, OpKernelContext* ctx, int branch,
          std::vector<FHandle> branch_handles, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          branch_(branch),
          branch_handles_(branch_handles),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      int branch = branch_;
      // The last branch is the default branch.
      if (branch < 0 || branch >= branch_handles_.size()) {
        branch = branch_handles_.size() - 1;
      }
      rets_.clear();
      lib_->Run(
          // Evaluate one of the branches.
          opts_, branch_handles_[branch], args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    CaseOp* const kernel_;
    OpKernelContext* const ctx_;
    const int branch_;
    std::vector<FHandle> branch_handles_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };
};

// TODO(drpng): remove this.
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
                        IfOp);

REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

REGISTER_KERNEL_BUILDER(Name("Case").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("Case").Device(DEVICE_GPU).HostMemory("branch_index"), CaseOp);

REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

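// WhileOp alternately evaluates `cond` and `body` until `cond` produces a
// false (or empty) value. The op's inputs seed the loop variables; each
// iteration's body outputs become the next iteration's arguments, and the
// final loop variables are returned as the op's outputs.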
class WhileOp : public AsyncOpKernel {
 public:
  explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
  }

  ~WhileOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which
    // have different associated `FunctionLibraryRuntime` objects and hence
    // different `FHandle` namespaces. So we must call Instantiate() to make
    // sure we get the correct function handles with respect to `lib`. Note
    // the underlying `lib->Instantiate()` caches the created function
    // handles, so calling `Instantiate()` repeatedly on the same `lib` and
    // function is cheap.
    FHandle cond_handle;
    FHandle body_handle;
    OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done);
    OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done);
    (new State(this, ctx, cond_handle, body_handle, done))->Start();
  }

 private:
  NameAttrList cond_func_;
  NameAttrList body_func_;

  class State {
   public:
    State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
          FHandle body_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_handle_(cond_handle),
          body_handle_(body_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
      for (int i = 0; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() { EvalCond(); }

   private:
    WhileOp* const kernel_;
    OpKernelContext* const ctx_;
    const FHandle cond_handle_;
    const FHandle body_handle_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;

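    // The loop runs as an asynchronous state machine: EvalCond() runs the
    // cond function, StartBody() converts its result to a bool and either
    // runs the body (which loops back into EvalCond()) or calls Finish(),
    // which forwards the final loop variables and deletes this State.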
    void EvalCond() {
      lib_->Run(
          // Evaluate the condition.
          opts_, cond_handle_, args_, &rets_,
          // Done cb.
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            StartBody();
          });
    }

    void StartBody() {
      Status s;
      if (rets_.size() != 1) {
        s = errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            rets_.size(), " tensors.");
        return Finish(s);
      }
      Tensor cond_t;
#if GOOGLE_CUDA
      const DeviceBase::GpuDeviceInfo* gpu_device_info =
          ctx_->device()->tensorflow_gpu_device_info();
      const bool is_hostmem_dtype =
          rets_[0].dtype() == DT_INT32 || rets_[0].dtype() == DT_INT64;
      if (!is_hostmem_dtype && gpu_device_info &&
          (opts_.rets_alloc_attrs.empty() ||
           !opts_.rets_alloc_attrs[0].on_host())) {
        // Copy the ret value to host if it's allocated on device.
        Device* device = down_cast<Device*>(ctx_->device());
        DeviceContext* device_ctx = ctx_->op_device_context();
        cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
        Notification done_copy;
        device_ctx->CopyDeviceTensorToCPU(
            &rets_[0], /*tensor_name=*/"", device, &cond_t,
            [&done_copy, &s](const Status& status) {
              s = status;
              done_copy.Notify();
            });
        done_copy.WaitForNotification();
        if (!s.ok()) {
          return Finish(s);
        }
      } else {
        cond_t = rets_[0];
      }
#else
      cond_t = rets_[0];
#endif
      bool cond;
      s = ToBool({cond_t}, &cond);

      if (!s.ok()) {
        return Finish(s);
      }
      if (!cond) {
        return Finish(Status::OK());
      }
      rets_.clear();
      lib_->Run(
          // Evaluate the body.
          opts_, body_handle_, args_, &rets_,
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            if (args_.size() != rets_.size()) {
              return Finish(errors::InvalidArgument(
                  "While loop body returned ", rets_.size(),
                  " arguments. Expected: ", args_.size()));
            }
            args_.clear();
            using std::swap;
            swap(args_, rets_);
            EvalCond();
          });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, args_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};
// TODO(drpng): remove these.
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp);

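// Reads the int32 scalar input at "index" into *value. Returns an
// InvalidArgument error mentioning "label" if that input is not a scalar.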
Status GetScalar(OpKernelContext* ctx, int index, int32* value,
                 const char* label) {
  Tensor t = ctx->input(index);
  if (!TensorShapeUtils::IsScalar(t.shape())) {
    return errors::InvalidArgument(label, " must be a scalar, but has shape ",
                                   t.shape().DebugString());
  }
  *value = t.scalar<int32>()();
  return Status::OK();
}

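// ForOp runs `body(i, loop_vars...)` for i = start, start + delta, ... while
// i has not reached limit (the comparison direction follows the sign of
// delta). Inputs 0-2 are the scalar start/limit/delta; the remaining inputs
// seed the loop variables, and the body must return one tensor per loop
// variable.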
class ForOp : public AsyncOpKernel {
 public:
  explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    const NameAttrList* func;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
    OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_));
  }

  ~ForOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    (new State(this, ctx, done))->Start();
  }

 private:
  FHandle body_handle_;

  class State {
   public:
    State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())),
          args_(1 + ctx_->num_inputs() - 3) {
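      // args_[0] holds the iteration counter; args_[1..] hold the carried
      // loop variables. Inputs 0-2 (start/limit/delta) are not passed to the
      // body.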
      args_[0] = Tensor(DT_INT32, {});
      iter_ = &args_[0].scalar<int32>()();

      const int32 num_loop_inputs = ctx_->num_inputs() - 3;
      rets_.reserve(num_loop_inputs);
      for (int i = 0; i < num_loop_inputs; ++i) {
        rets_.push_back(ctx_->input(3 + i));
      }
    }

    ~State() {}

    void Start() {
      Status s = StartLoop();
      if (!s.ok()) Finish(s);
    }

   private:
    ForOp* const kernel_;
    OpKernelContext* const ctx_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;

    int32* iter_;  // points to args_[0].
    int32 limit_;
    int32 delta_;

    // If an error e is returned, caller must call Finish(e).
    // If OK is returned, the async loop execution has been started.
    Status StartLoop() {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);

      TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));

      if ((delta_ > 0 && *iter_ <= limit_) ||
          (delta_ < 0 && *iter_ >= limit_) ||
          (delta_ == 0 && *iter_ == limit_)) {
        RunNext();
        return Status::OK();
      } else {
        return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_,
                                       " ", limit_, " ", delta_);
      }
    }

    void RunNext() {
      bool done_loop;
      if (delta_ > 0) {
        done_loop = *iter_ >= limit_;
      } else {
        done_loop = *iter_ <= limit_;
      }
      if (done_loop) {
        Finish(Status::OK());
        return;
      }

      if (rets_.size() >= args_.size()) {
        Finish(errors::InvalidArgument(
            "For loop body returned ", rets_.size(),
            " arguments. Expected: ", args_.size() - 1));
        return;
      }
      for (int i = 0; i < rets_.size(); ++i) {
        args_[1 + i] = std::move(rets_[i]);
      }
      rets_.clear();
      lib_->Run(opts_, kernel_->body_handle_, args_, &rets_,
                [this](const Status& s) {
                  if (s.ok()) {
                    *iter_ += delta_;
                    RunNext();
                  } else {
                    Finish(s);
                  }
                });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, rets_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};

REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp);
REGISTER_KERNEL_BUILDER(Name("For")
                            .Device(DEVICE_GPU)
                            .HostMemory("start")
                            .HostMemory("limit")
                            .HostMemory("delta"),
                        ForOp);

// FakeParamOp allocates a tensor with a shape conforming to the expected
// output. This is necessary if the value will be stored in a while_loop's
// TensorList. The output is otherwise not expected to be consumed by anything
// else.
class FakeParamOp : public OpKernel {
 public:
  explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
    DataType dtype;
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));

    // Set shape to the specified shape, setting unknown dimensions to empty.
    // If the specified shape is unknown, leave as an empty shape.
    TensorShape shape;
    PartialTensorShape partial_shape;
    OP_REQUIRES_OK(context, context->GetAttr("shape", &partial_shape));
    if (!partial_shape.unknown_rank()) {
      for (int64 d : partial_shape.dim_sizes()) {
        shape.AddDim(d == -1 ? 0 : d);
      }
    }

    // Create a persistent tensor that we can repeatedly return to save memory.
    // TODO(b/119612758): add optimization to prevent sending this across
    // devices on each Compute() call.
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                dtype, shape, &value_handle_, nullptr));
  }

  void Compute(OpKernelContext* context) override {
    context->set_output(0, *value_handle_.AccessTensor(context));
  }

 private:
  PersistentTensor value_handle_;
};

REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);

}  // namespace
}  // namespace tensorflow