/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/control_flow_ops.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

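// Switch forwards its data input to one of two outputs, selected by the
// scalar boolean `pred` input: output_false (port 0) when `pred` is false,
// and output_true (port 1) when `pred` is true.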
void SwitchOp::Compute(OpKernelContext* context) {
  const Tensor& output_ports = context->input(1);
  OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_ports.shape()),
              errors::InvalidArgument("The second input must be a scalar, "
                                      "but it has shape ",
                                      output_ports.shape().DebugString()));

  bool pred = output_ports.scalar<bool>()();
  int port = pred ? 1 : 0;
  if (context->input_is_ref(0)) {
    context->forward_ref_input_to_ref_output(0, port);
  } else {
    context->set_output(port, context->input(0));
  }
}

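// _SwitchN forwards its data input to the output selected by the scalar
// `output_index` input; any out-of-range index is routed to the last output.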
void SwitchNOp::Compute(OpKernelContext* context) {
  const Tensor& output_index_t = context->input(1);
  OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_index_t.shape()),
              errors::InvalidArgument("The second input must be a scalar, "
                                      "but it has shape ",
                                      output_index_t.shape().DebugString()));
  int output_index = output_index_t.scalar<int>()();
  if (output_index < 0 || output_index >= num_outputs()) {
    output_index = num_outputs() - 1;
  }
  context->set_output(output_index, context->input(0));
}

REGISTER_KERNEL_BUILDER(
    Name("Switch").Device(DEVICE_DEFAULT).HostMemory("pred"), SwitchOp);
REGISTER_KERNEL_BUILDER(
    Name("Switch").Device(DEVICE_TPU_SYSTEM).HostMemory("pred"), SwitchOp);

REGISTER_KERNEL_BUILDER(
    Name("_SwitchN").Device(DEVICE_DEFAULT).HostMemory("output_index"),
    SwitchNOp);

#define REGISTER_CPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("output_index") \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

#define REGISTER_CPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_GPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("output_index") \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

#define REGISTER_GPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
TF_CALL_variant(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_REF_SWITCH);

#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH

// Special GPU kernels for int32, string and resource handles. These require
// all inputs and outputs to be in host memory.
// TODO(b/25387198): Also enable int32 in device memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output_index") \
                              .HostMemory("outputs")      \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL

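// RefSelect forwards one of its N ref inputs, chosen by the scalar `index`
// input (input 0), to its single ref output.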
class RefSelectOp : public OpKernel {
 public:
  explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& index_tensor = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()),
                errors::InvalidArgument("Index must be a scalar, "
                                        "but it has shape ",
                                        index_tensor.shape().DebugString()));

    int32_t index = index_tensor.scalar<int32>()();

    OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_,
                errors::InvalidArgument("Index must be in the range [0, ",
                                        num_ref_inputs_, ") but got ", index));
    // Input 0 is the index, so the selected ref input is at index + 1.
    context->forward_ref_input_to_ref_output(index + 1, 0);
  }

  bool IsExpensive() override { return false; }

  ~RefSelectOp() override {}

  TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp);

 private:
  int num_ref_inputs_;
};

#define REGISTER_CPU_REF_SELECT(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSelect")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("index")        \
                              .TypeConstraint<type>("T"), \
                          RefSelectOp)
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT);

#undef REGISTER_CPU_REF_SELECT

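// Merge forwards the value of the first available input to its first output,
// and emits the index of that input on its second output (value_index). At
// most one input is expected to be available per execution.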
MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) {
  const DataType dt = context->input_type(0);
  const int num_in = context->num_inputs();
  OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
                                                  {dt, DT_INT32}));
}

void MergeOp::Compute(OpKernelContext* context) {
  bool input_seen = false;
  for (int i = 0; i < context->num_inputs(); ++i) {
    if (context->has_input(i)) {
      if (input_seen) {
        context->SetStatus(
            errors::Internal("Merge cannot have more than one valid input."));
        return;
      }
      input_seen = true;

      if (IsRefType(context->input_dtype(i))) {
        context->forward_ref_input_to_ref_output(i, 0);
      } else {
        context->set_output(0, context->input(i));
      }
      // The value_index output is typically used only in gradient
      // calculations, so we can avoid allocating it in many inference
      // workloads.
      if (context->output_required(1)) {
        Tensor* value_index = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
                                                         &value_index));
        value_index->scalar<int32>()() = i;
      }
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
REGISTER_KERNEL_BUILDER(
    Name("Merge").Device(DEVICE_DEFAULT).HostMemory("value_index"), MergeOp);
REGISTER_KERNEL_BUILDER(
    Name("Merge").Device(DEVICE_TPU_SYSTEM).HostMemory("value_index"), MergeOp);
REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);

#define REGISTER_GPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

#define REGISTER_GPU_REF_KERNEL(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

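// Enter makes its input available to a child (loop) frame. The executor
// handles the frame bookkeeping; the kernel itself is a pass-through.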
void EnterOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_DEFAULT), EnterOp);
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU_SYSTEM), EnterOp);
REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);

#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
#define REGISTER_GPU_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(            \
      Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Enter")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL

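// Exit makes its input available to the parent frame, leaving the current
// loop frame. Like Enter, the kernel is a pass-through.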
void ExitOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_DEFAULT), ExitOp);
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU_SYSTEM), ExitOp);
REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);

#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
#define REGISTER_GPU_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(            \
      Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

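// NextIteration makes its input available to the next iteration of the
// enclosing loop; the kernel again simply forwards its input.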
void NextIterationOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_DEFAULT),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU_SYSTEM),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
                        NextIterationOp);

#define REGISTER_GPU_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
      NextIterationOp);                                                      \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      NextIterationOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

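// LoopCond forwards its boolean input (the loop predicate) to its output,
// failing fast if the step has already been cancelled.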
LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
LoopCondOp::~LoopCondOp() = default;

void LoopCondOp::Compute(OpKernelContext* context) {
  CancellationManager* cm = context->cancellation_manager();
  if (cm != nullptr) {
    bool already_cancelled = cm->IsCancelled();
    OP_REQUIRES(context, !already_cancelled,
                errors::Cancelled("Loop execution was cancelled."));
  }

  context->set_output(0, context->input(0));
}

bool LoopCondOp::IsExpensive() { return false; }

REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);

// ControlTrigger kernel. The op does nothing; it exists only as a placeholder
// for control edges.
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_DEFAULT),
                        ControlTriggerOp);
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_TPU_SYSTEM),
                        ControlTriggerOp);

// When called, the Abort op aborts the current process. This can be used to
// abort remote parameter servers when needed.
class AbortOp : public OpKernel {
 public:
  explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("error_msg", &error_msg_));
    OP_REQUIRES_OK(
        context, context->GetAttr("exit_without_error", &exit_without_error_));
  }

  void Compute(OpKernelContext* context) override {
    if (!exit_without_error_) {
      LOG(FATAL) << "Abort_op intentional failure; " << error_msg_;
    } else {
      LOG(WARNING) << "Exiting the process: " << error_msg_;
      exit(0);
    }
  }

 private:
  string error_msg_;
  bool exit_without_error_;
};

REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);

}  // namespace tensorflow