/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/types.h"
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#endif
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
typedef std::vector<Tensor> TensorVec;

namespace {

// Helper to instantiate function "func" in the library "lib".
Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}

Status Instantiate(OpKernelContext* ctx, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  FunctionLibraryRuntime::InstantiateOptions opts;
  opts.executor_type = ctx->executor_type();
  return ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), opts, handle);
}

// If "t" is a scalar of a supported type, returns t != 0 in "*v".
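// A non-scalar tensor converts to (NumElements() > 0); for example, a
// DT_INT32 scalar holding 0 converts to false, and a non-empty DT_STRING
// scalar converts to true.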
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
  if (t.size() != 1) {
    return errors::InvalidArgument(
        "Expected a single scalar which can be converted to a boolean, got ",
        t.size(), " tensors.");
  }
  if (TensorShapeUtils::IsScalar(t[0].shape())) {
    switch (t[0].dtype()) {
#define CASE(T)                   \
  case DataTypeToEnum<T>::value:  \
    *v = t[0].scalar<T>()() != 0; \
    break;

      CASE(float);
      CASE(double);
      CASE(int32);
      CASE(uint8);
      CASE(int16);
      CASE(int8);
      CASE(int64);
#undef CASE
      case DT_BOOL:
        *v = t[0].scalar<bool>()();
        break;
      case DT_STRING:
        *v = !t[0].scalar<tstring>()().empty();
        break;
      default:
        return errors::InvalidArgument(DataTypeString(t[0].dtype()),
                                       " cannot be converted to a boolean");
    }
  } else {
    *v = t[0].NumElements() > 0;
  }
  return Status::OK();
}

// Sets "rets" to be the output of "ctx". Validates rets' types based
// on "kernel".
Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
                  gtl::ArraySlice<Tensor> rets) {
  if (rets.size() != ctx->num_outputs()) {
    return errors::Internal("Expected to produce ", ctx->num_outputs(),
                            " tensors, but only got ", rets.size());
  }
  for (int i = 0; i < rets.size(); ++i) {
    if (rets[i].dtype() != kernel->output_type(i)) {
      return errors::Internal("Expected the ", i, "-th output to be of type ",
                              DataTypeString(kernel->output_type(i)),
                              " but got ", DataTypeString(rets[i].dtype()));
    }
    ctx->set_output(i, rets[i]);
  }
  return Status::OK();
}

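// Copies per-step execution options (rendezvous, cancellation manager,
// collective executor, runner, step container, and optionally the stats
// collector) from the calling kernel's context into the options used for the
// child function call.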
void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
                   bool always_collect_stats) {
  opts->rendezvous = ctx->rendezvous();
  opts->cancellation_manager = ctx->cancellation_manager();
  opts->collective_executor = ctx->collective_executor();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  opts->runner = ctx->runner();
  opts->run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts->step_container = ctx->step_container();
}

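// IfOp invokes either the `then_branch` or the `else_branch` function,
// depending on the boolean (or boolean-convertible) scalar in input 0, and
// forwards the remaining inputs as the branch function's arguments.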
class IfOp : public AsyncOpKernel {
 public:
  explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
  }

  ~IfOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FHandle then_handle;
    FHandle else_handle;
    OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle),
                         done);
    bool cond;
    OP_REQUIRES_OK_ASYNC(ctx, ToBool({ctx->input(0)}, &cond), done);
    (new State(this, ctx, cond, then_handle, else_handle, done))->Start();
  }

 private:
  NameAttrList then_func_;
  NameAttrList else_func_;

  mutex mu_;
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ ABSL_GUARDED_BY(mu_);

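  // Heap-allocated per-invocation state. Start() runs the selected branch
  // through the FunctionLibraryRuntime; the done callback publishes the
  // results and the object deletes itself.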
  class State {
   public:
    State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
          FHandle else_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_(cond),
          then_handle_(then_handle),
          else_handle_(else_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      FHandle handle = cond_ ? then_handle_ : else_handle_;
      rets_.clear();
      profiler::TraceMe trace_me("IfOp");
      lib_->Run(
          // Evaluate one of the branches.
          opts_, handle, args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    IfOp* const kernel_;
    OpKernelContext* const ctx_;
    const bool cond_;
    FHandle then_handle_;
    FHandle else_handle_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };

  Status GetHandles(OpKernelContext* ctx, FHandle* then_handle,
                    FHandle* else_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    auto lib = ctx->function_library();
    if (lib == nullptr) return errors::Internal("No function library");
    *then_handle = kInvalidHandle;
    *else_handle = kInvalidHandle;
    {
      tf_shared_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *then_handle = iter->second.first;
        *else_handle = iter->second.second;
      }
    }
    if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) {
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *then_handle = iter->second.first;
        *else_handle = iter->second.second;
      } else {
        TF_RETURN_IF_ERROR(Instantiate(ctx, then_func_, then_handle));
        TF_RETURN_IF_ERROR(Instantiate(ctx, else_func_, else_handle));
        handles_[lib] = {*then_handle, *else_handle};
      }
    }
    return Status::OK();
  }
};

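// CaseOp dispatches on the scalar int32 `branch_index` input to one of the
// functions in the `branches` attribute; an out-of-range index falls through
// to the last branch, which acts as the default.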
class CaseOp : public AsyncOpKernel {
 public:
  explicit CaseOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branch_funcs_));
  }

  ~CaseOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which
    // have different associated `FunctionLibraryRuntime` objects and hence
    // different `FHandle` namespaces. So we must call Instantiate() to make
    // sure we get the correct function handles with respect to `lib`. Note the
    // underlying `lib->Instantiate()` caches the created function handles, so
    // calling `Instantiate()` repeatedly on the same `lib` and function is
    // cheap.
    std::vector<FHandle> branch_handles(branch_funcs_.size());
    for (int i = 0; i < branch_funcs_.size(); i++) {
      OP_REQUIRES_OK_ASYNC(
          ctx, Instantiate(lib, branch_funcs_[i], &branch_handles[i]), done);
    }

    const Tensor& branch_index = ctx->input(0);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()),
                      errors::InvalidArgument("branch_index must be scalar"),
                      done);
    int32 branch = branch_index.scalar<int32>()();
    (new State(this, ctx, branch, branch_handles, done))->Start();
  }

 private:
  std::vector<NameAttrList> branch_funcs_;

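  // Heap-allocated per-invocation state, analogous to IfOp::State: runs the
  // selected branch and deletes itself from the done callback.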
  class State {
   public:
    State(CaseOp* kernel, OpKernelContext* ctx, int branch,
          std::vector<FHandle> branch_handles, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          branch_(branch),
          branch_handles_(branch_handles),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      int branch = branch_;
      // The last branch is the default branch.
      if (branch < 0 || branch >= branch_handles_.size()) {
        branch = branch_handles_.size() - 1;
      }
      rets_.clear();
      profiler::TraceMe trace_me("CaseOp");
      lib_->Run(
          // Evaluate one of the branches.
          opts_, branch_handles_[branch], args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    CaseOp* const kernel_;
    OpKernelContext* const ctx_;
    const int branch_;
    std::vector<FHandle> branch_handles_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };
};

// TODO(drpng): remove this.
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
                        IfOp);

REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

REGISTER_KERNEL_BUILDER(Name("Case").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("Case").Device(DEVICE_GPU).HostMemory("branch_index"), CaseOp);
REGISTER_KERNEL_BUILDER(Name("StatelessCase").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessCase").Device(DEVICE_GPU).HostMemory("branch_index"),
    CaseOp);

REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

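// WhileOp repeatedly evaluates the `cond` function on the loop variables and,
// while it returns true, runs the `body` function to produce the next set of
// loop variables. It has both a callback-driven asynchronous implementation
// and a synchronous one (used when kernels run inline or when Compute() is
// called directly).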
class WhileOp : public AsyncOpKernel {
 public:
  explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
  }

  ~WhileOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    if (ctx->run_all_kernels_inline()) {
      // Use the non-callback-based implementation when kernels (and function
      // callbacks) execute inline to avoid stack overflow.
      OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done);
      done();  // The async kernel contract requires `done` on success too.
    } else {
      FHandle cond_handle;
      FHandle body_handle;
      OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
                           done);
      (new State(this, ctx, cond_handle, body_handle, done))->Start();
    }
  }

  void Compute(OpKernelContext* ctx) override {
    // Use the non-callback-based implementation when the synchronous Compute()
    // method is invoked, because the caller is explicitly donating a thread.
    Status s = DoComputeSync(ctx);
    // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is
    // still an AsyncOpKernel, and there is a run-time check to avoid calling
    // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in
    // the event of an error).
    if (TF_PREDICT_FALSE(!s.ok())) {
      ctx->SetStatus(s);
    }
  }

 private:
  NameAttrList cond_func_;
  NameAttrList body_func_;

  mutex mu_;
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ ABSL_GUARDED_BY(mu_);

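  // Converts the tensor produced by the `cond` function to a bool. On GPU
  // builds, if the result lives in device memory it is first copied to the
  // host before conversion.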
  static Status CondResultToBool(OpKernelContext* ctx,
                                 const FunctionLibraryRuntime::Options& opts,
                                 const Tensor& cond_t, bool* out_result) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    const DeviceBase::GpuDeviceInfo* gpu_device_info =
        ctx->device()->tensorflow_gpu_device_info();
    const bool is_hostmem_dtype =
        cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64;
    if (!is_hostmem_dtype && gpu_device_info &&
        (opts.rets_alloc_attrs.empty() ||
         !opts.rets_alloc_attrs[0].on_host())) {
      // Copy the ret value to host if it's allocated on device.
      Device* device = down_cast<Device*>(ctx->device());
      DeviceContext* device_ctx = ctx->op_device_context();
      Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape());
      TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync(
          &cond_t, /*tensor_name=*/"", device, &host_cond_t));
      return ToBool({host_cond_t}, out_result);
    }
#endif
    return ToBool({cond_t}, out_result);
  }

  // The initial loop variable args are the inputs to the kernel.
  //
  // We attempt to forward the input so that it can be consumed inside the
  // body function (and participate in buffer forwarding, etc.).
  static void GetArgsFromContext(OpKernelContext* ctx,
                                 std::vector<Tensor>* out_args,
                                 DataTypeVector* out_var_types) {
    const int num_loop_vars = ctx->num_inputs();
    out_args->reserve(num_loop_vars);
    out_var_types->resize(num_loop_vars);
    for (int i = 0; i < num_loop_vars; ++i) {
      const Tensor& input = ctx->input(i);
      (*out_var_types)[i] = input.dtype();
      std::unique_ptr<Tensor> maybe_forwarded_input = ctx->forward_input(
          i, /* output_index= */ OpKernelContext::Params::kNoReservation,
          input.dtype(), input.shape(), ctx->input_memory_type(i),
          ctx->input_alloc_attr(i));
      if (maybe_forwarded_input) {
        out_args->push_back(std::move(*maybe_forwarded_input));
      } else {
        out_args->push_back(input);
      }
    }
  }

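  // CallFrameInterface implementation that lets the body function consume the
  // loop-variable tensors in place (so buffers can be forwarded) and that
  // type-checks the values the body returns.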
  class BodyFuncCallFrame : public CallFrameInterface {
   public:
    BodyFuncCallFrame(std::vector<Tensor>* args, std::vector<Tensor>* retvals,
                      DataTypeSlice ret_types)
        : args_(args), retvals_(retvals), ret_types_(ret_types) {}

    size_t num_args() const override { return args_->size(); }
    size_t num_retvals() const override { return retvals_->size(); }

    Status GetArg(int index, const Tensor** val) override {
      if (index < args_->size()) {
        *val = &(*args_)[index];
        return Status::OK();
      } else {
        return errors::InvalidArgument("Argument ", index,
                                       " is out of range.");
      }
    }

    void ConsumeArg(int index, Tensor* val) override {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, args_->size());
      *val = std::move((*args_)[index]);
    }
    bool CanConsumeArg(int index) const override {
      return index >= 0 && index < args_->size();
    }

    Status SetRetval(int index, const Tensor& val) override {
      if (TF_PREDICT_FALSE(index < 0)) {
        return errors::InvalidArgument(
            "Expected non-negative return value index, but got: ", index, ".");
      } else if (TF_PREDICT_FALSE(index >= retvals_->size())) {
        return errors::InvalidArgument("While loop body returned ", index + 1,
                                       " arguments. Expected: ", num_retvals(),
                                       ".");
      } else if (TF_PREDICT_FALSE(val.dtype() != ret_types_[index])) {
        return errors::InvalidArgument("Expected type ",
                                       DataTypeString(ret_types_[index]),
                                       " for return value ", index, " but got ",
                                       DataTypeString(val.dtype()), ".");
      }
      (*retvals_)[index] = val;
      return Status::OK();
    }

   private:
    std::vector<Tensor>* const args_;     // Not owned.
    std::vector<Tensor>* const retvals_;  // Not owned.
    DataTypeSlice ret_types_;

    TF_DISALLOW_COPY_AND_ASSIGN(BodyFuncCallFrame);
  };

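  // Per-invocation state for the asynchronous implementation. EvalCond() and
  // StartBody() alternate via the FunctionLibraryRuntime's done callbacks
  // until the condition is false; Finish() then publishes the loop variables
  // as the op's outputs and deletes the state.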
  class State {
   public:
    State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
          FHandle body_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_handle_(cond_handle),
          body_handle_(body_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
      GetArgsFromContext(ctx, &args_, &loop_var_types_);
      body_frame_ = absl::make_unique<BodyFuncCallFrame>(&args_, &rets_,
                                                         loop_var_types_);
    }

    ~State() {}

    void Start() { EvalCond(); }

   private:
    WhileOp* const kernel_;
    OpKernelContext* const ctx_;
    const FHandle cond_handle_;
    const FHandle body_handle_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
    DataTypeVector loop_var_types_;
    std::unique_ptr<BodyFuncCallFrame> body_frame_;

    void EvalCond() {
      profiler::TraceMe trace_me("WhileOp-EvalCond");
      lib_->Run(
          // Evaluate the condition.
          opts_, cond_handle_, args_, &rets_,
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            StartBody();
          });
    }

    void StartBody() {
      Status s;
      if (rets_.size() != 1) {
        s = errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            rets_.size(), " tensors.");
        return Finish(s);
      }

      bool cond;
      s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
      if (!s.ok()) {
        return Finish(s);
      }

      if (!cond) {
        return Finish(Status::OK());
      }
      rets_.clear();
      rets_.resize(args_.size());
      profiler::TraceMe trace_me("WhileOp-StartBody");
      lib_->Run(
          // Evaluate the body.
          opts_, body_handle_, body_frame_.get(),
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            if (args_.size() != rets_.size()) {
              return Finish(errors::InvalidArgument(
                  "While loop body returned ", rets_.size(),
                  " arguments. Expected: ", args_.size()));
            }
            args_.clear();
            using std::swap;
            swap(args_, rets_);
            EvalCond();
          });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, args_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };

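  // Synchronous implementation, used when kernels execute inline or when the
  // blocking Compute() entry point is invoked. It expresses the while loop as
  // a plain C++ do-while over the cond and body functions.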
  Status DoComputeSync(OpKernelContext* ctx) {
    FHandle cond_handle;
    FHandle body_handle;
    TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle));
    auto lib = ctx->function_library();
    FunctionLibraryRuntime::Options opts;
    SetRunOptions(ctx, &opts, false /* always_collect_stats */);

    // Pre-allocate argument and return value vectors for the cond and body
    // functions.
    std::vector<Tensor> args;
    const int num_loop_vars = ctx->num_inputs();
    DataTypeVector loop_var_types(num_loop_vars);
    GetArgsFromContext(ctx, &args, &loop_var_types);
    std::vector<Tensor> cond_rets;
    cond_rets.reserve(1);
    std::vector<Tensor> body_rets;
    body_rets.reserve(num_loop_vars);

    // Implement the logic of the while loop as a single C++ do-while loop that
    // executes the cond and body functions synchronously.
    do {
      // Evaluate the cond function on the current loop variables.
      {
        profiler::TraceMe trace_me("WhileOp-EvalCond");
        TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
      }
      if (cond_rets.size() != 1) {
        return errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            cond_rets.size(), " tensors.");
      }

      // If the cond function evaluates to false, we are done: output the
      // current loop variables.
      bool cond_result;
      TF_RETURN_IF_ERROR(
          CondResultToBool(ctx, opts, cond_rets[0], &cond_result));
      if (!cond_result) {
        return SetOutputs(this, ctx, args);
      }

      // Evaluate the body function on the current loop variables, to get an
      // updated vector of loop variables.
      {
        profiler::TraceMe trace_me("WhileOp-StartBody");
        body_rets.resize(num_loop_vars);
        BodyFuncCallFrame call_frame(&args, &body_rets, loop_var_types);
        TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, &call_frame));
      }
      std::swap(body_rets, args);
      body_rets.clear();
    } while (true);
  }

  Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
                    FHandle* body_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    auto lib = ctx->function_library();
    if (lib == nullptr) return errors::Internal("No function library");
    *cond_handle = kInvalidHandle;
    *body_handle = kInvalidHandle;
    {
      tf_shared_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *cond_handle = iter->second.first;
        *body_handle = iter->second.second;
      }
    }
    if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) {
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *cond_handle = iter->second.first;
        *body_handle = iter->second.second;
      } else {
        TF_RETURN_IF_ERROR(Instantiate(ctx, cond_func_, cond_handle));
        TF_RETURN_IF_ERROR(Instantiate(ctx, body_func_, body_handle));
        handles_[lib] = {*cond_handle, *body_handle};
      }
    }
    return Status::OK();
  }
};
// TODO(drpng): remove these.
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp);

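// ToBoolOp exposes the ToBool() conversion above as a standalone kernel that
// produces a boolean scalar from its single input.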
class ToBoolOp : public OpKernel {
 public:
  explicit ToBoolOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    bool b;
    OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &b));
    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    out->scalar<bool>()() = b;
  }
};

REGISTER_KERNEL_BUILDER(Name("ToBool").Device(DEVICE_CPU), ToBoolOp);

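// Reads input `index` as an int32 scalar into `*value`, using `label` to
// identify the input in error messages.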
Status GetScalar(OpKernelContext* ctx, int index, int32* value,
                 const char* label) {
  Tensor t = ctx->input(index);
  if (!TensorShapeUtils::IsScalar(t.shape())) {
    return errors::InvalidArgument(label, " must be a scalar, but got shape ",
                                   t.shape().DebugString());
  }
  *value = t.scalar<int32>()();
  return Status::OK();
}

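// ForOp runs `body` once for every value of the induction variable from
// `start` (inclusive) to `limit` (exclusive) in steps of `delta`, threading
// the loop-carried tensors through successive body calls.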
class ForOp : public AsyncOpKernel {
 public:
  explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    const NameAttrList* func;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
    OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_));
  }

  ~ForOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    (new State(this, ctx, done))->Start();
  }

 private:
  FHandle body_handle_;

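  // Per-invocation loop state: args_[0] holds the induction variable and the
  // remaining args_ slots carry the loop variables, which are refreshed from
  // rets_ after every body call.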
  class State {
   public:
    State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())),
          args_(1 + ctx_->num_inputs() - 3) {
      args_[0] = Tensor(DT_INT32, {});
      iter_ = &args_[0].scalar<int32>()();

      const int32 num_loop_inputs = ctx_->num_inputs() - 3;
      rets_.reserve(num_loop_inputs);
      for (int i = 0; i < num_loop_inputs; ++i) {
        rets_.push_back(ctx_->input(3 + i));
      }
    }

    ~State() {}

    void Start() {
      Status s = StartLoop();
      if (!s.ok()) Finish(s);
    }

   private:
    ForOp* const kernel_;
    OpKernelContext* const ctx_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;

    int32* iter_;  // points to args_[0].
    int32 limit_;
    int32 delta_;

    // If an error e is returned, caller must call Finish(e).
    // If OK is returned, the async loop execution has been started.
    Status StartLoop() {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);

      TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));

      if ((delta_ > 0 && *iter_ <= limit_) ||
          (delta_ < 0 && *iter_ >= limit_) ||
          (delta_ == 0 && *iter_ == limit_)) {
        RunNext();
        return Status::OK();
      } else {
        return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_,
                                       " ", limit_, " ", delta_);
      }
    }

    void RunNext() {
      bool done_loop;
      if (delta_ > 0) {
        done_loop = *iter_ >= limit_;
      } else {
        done_loop = *iter_ <= limit_;
      }
      if (done_loop) {
        Finish(Status::OK());
        return;
      }

      if (rets_.size() >= args_.size()) {
        Finish(errors::InvalidArgument(
            "For loop body returned ", rets_.size(),
            " arguments. Expected: ", args_.size() - 1));
        return;
      }
      for (int i = 0; i < rets_.size(); ++i) {
        args_[1 + i] = std::move(rets_[i]);
      }
      rets_.clear();
      profiler::TraceMe trace_me("ForOp");
      lib_->Run(opts_, kernel_->body_handle_, args_, &rets_,
                [this](const Status& s) {
                  if (s.ok()) {
                    *iter_ += delta_;
                    RunNext();
                  } else {
                    Finish(s);
                  }
                });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, rets_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};

REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp);
REGISTER_KERNEL_BUILDER(Name("For")
                            .Device(DEVICE_GPU)
                            .HostMemory("start")
                            .HostMemory("limit")
                            .HostMemory("delta"),
                        ForOp);

// FakeParamOp allocates a tensor with a shape conforming to the expected
// output. This is necessary if the value will be stored in a while_loop's
// TensorList. The output is otherwise not expected to be consumed by anything
// else.
class FakeParamOp : public OpKernel {
 public:
  explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
    DataType dtype;
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));

    // Set the shape to the specified shape, replacing unknown dimensions with
    // 0. If the specified shape is entirely unknown, leave it as an empty
    // shape.
    TensorShape shape;
    PartialTensorShape partial_shape;
    OP_REQUIRES_OK(context, context->GetAttr("shape", &partial_shape));
    if (!partial_shape.unknown_rank()) {
      for (int64 d : partial_shape.dim_sizes()) {
        shape.AddDim(d == -1 ? 0 : d);
      }
    }

    // Create a persistent tensor that we can repeatedly return to save memory.
    // TODO(b/119612758): add optimization to prevent sending this across
    // devices on each Compute() call.
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                dtype, shape, &value_handle_, nullptr));
  }

  void Compute(OpKernelContext* context) override {
    context->set_output(0, *value_handle_.AccessTensor(context));
  }

 private:
  PersistentTensor value_handle_;
};

REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);

// DeviceIndexOp returns the index of the current device's type within the
// `device_names` attribute, or `device_names.size()` if the type is not
// listed.
class DeviceIndexOp : public OpKernel {
 public:
  explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* device_name_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &device_name_t));
    DeviceNameUtils::ParsedName parsed_name;
    int index = device_names_.size();
    if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
        parsed_name.has_type) {
      auto it = absl::c_find(device_names_, parsed_name.type);
      if (it != device_names_.end()) {
        index = it - device_names_.begin();
      }
    }
    device_name_t->scalar<int32>()() = index;
  }

 private:
  std::vector<string> device_names_;
};

REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
REGISTER_KERNEL_BUILDER(
    Name("DeviceIndex").Device(DEVICE_GPU).HostMemory("index"), DeviceIndexOp);

}  // namespace
}  // namespace tensorflow