/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

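// Kernel implementations and registrations for TensorFlow's control-flow
// ops: Switch, RefSelect, Merge, Enter, Exit, NextIteration, LoopCond,
// ControlTrigger and Abort. The corresponding op definitions live in
// ../ops/control_flow_ops.cc.
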
#include "tensorflow/core/kernels/control_flow_ops.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

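// A Switch op forwards its data input (input 0) to one of two outputs, based
// on the boolean scalar `pred` (input 1): "output_false" (port 0) when pred
// is false, "output_true" (port 1) when pred is true. Ref inputs are
// forwarded as refs.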
void SwitchOp::Compute(OpKernelContext* context) {
  const Tensor& output_ports = context->input(1);
  OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_ports.shape()),
              errors::InvalidArgument("The second input must be a scalar, "
                                      "but it has shape ",
                                      output_ports.shape().DebugString()));

  bool pred = output_ports.scalar<bool>()();
  int port = pred ? 1 : 0;
  if (context->input_is_ref(0)) {
    context->forward_ref_input_to_ref_output(0, port);
  } else {
    context->set_output(port, context->input(0));
  }
}

#define REGISTER_CPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_CPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_GPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_GPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);
REGISTER_CPU_SWITCH(uint64);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_SWITCH(uint64);
TF_CALL_variant(REGISTER_GPU_SWITCH);

#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH

// Special GPU kernels for int32, bool, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_REF_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_REF_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_SWITCH(type)                        \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)
TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH);

#define REGISTER_SYCL_REF_SWITCH(type)                    \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)
TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);

#undef REGISTER_SYCL_SWITCH
#undef REGISTER_SYCL_REF_SWITCH

#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_SYCL_HOST_KERNEL(bool);
REGISTER_SYCL_HOST_KERNEL(string);
REGISTER_SYCL_HOST_KERNEL(int32);

#define REGISTER_SYCL_HOST_REF_KERNEL(type)               \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_SYCL_HOST_REF_KERNEL(int32);
REGISTER_SYCL_HOST_REF_KERNEL(bool);
REGISTER_SYCL_HOST_REF_KERNEL(string);

#undef REGISTER_SYCL_HOST_KERNEL
#undef REGISTER_SYCL_HOST_REF_KERNEL
#endif  // TENSORFLOW_USE_SYCL

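// A RefSelect op forwards one of its N ref data inputs (inputs 1..N) to its
// ref output, selected by the int32 scalar `index` (input 0).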
class RefSelectOp : public OpKernel {
 public:
  explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& index_tensor = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()),
                errors::InvalidArgument("Index must be a scalar, "
                                        "but it has shape ",
                                        index_tensor.shape().DebugString()));

    int32 index = index_tensor.scalar<int32>()();

    OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_,
                errors::InvalidArgument("Index must be in the range [0, ",
                                        num_ref_inputs_, ") but got ", index));
    context->forward_ref_input_to_ref_output(index + 1, 0);
  }

  bool IsExpensive() override { return false; }

  ~RefSelectOp() override {}

  TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp);

 private:
  int num_ref_inputs_;
};

#define REGISTER_CPU_REF_SELECT(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSelect")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("index")        \
                              .TypeConstraint<type>("T"), \
                          RefSelectOp)
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT);

#undef REGISTER_CPU_REF_SELECT

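// A Merge op forwards the value of the first available input to its output
// and records the index of that input in `value_index`. It raises an error
// if more than one of its inputs is available at once.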
MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) {
  const DataType dt = context->input_type(0);
  const int num_in = context->num_inputs();
  OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
                                                  {dt, DT_INT32}));
}

void MergeOp::Compute(OpKernelContext* context) {
  bool input_seen = false;
  for (int i = 0; i < context->num_inputs(); ++i) {
    if (context->has_input(i)) {
      if (input_seen) {
        context->SetStatus(
            errors::Internal("Merge cannot have more than one valid input."));
        return;
      }
      input_seen = true;

      if (IsRefType(context->input_dtype(i))) {
        context->forward_ref_input_to_ref_output(i, 0);
      } else {
        context->set_output(0, context->input(i));
      }
      Tensor* value_index = nullptr;
      OP_REQUIRES_OK(
          context, context->allocate_output(1, TensorShape({}), &value_index));
      value_index->scalar<int32>()() = i;
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);

#define REGISTER_GPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

#define REGISTER_GPU_REF_KERNEL(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
REGISTER_GPU_KERNEL(uint64);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type)                        \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#define REGISTER_SYCL_REF_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);

#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
REGISTER_SYCL_HOST_KERNEL(ResourceHandle);

#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

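// An Enter op simply forwards its input to its output. The executor uses it
// to mark the entry of a tensor into a child frame, e.g. the body of a while
// loop.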
void EnterOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp);
REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);

#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
#define REGISTER_GPU_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(            \
      Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(         \
      Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#define REGISTER_SYCL_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(             \
      Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);

#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("Enter")                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

#define REGISTER_SYCL_HOST_REF_KERNEL(type)               \
  REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_REF_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
REGISTER_SYCL_HOST_REF_KERNEL(string);
REGISTER_SYCL_HOST_KERNEL(ResourceHandle);

#undef REGISTER_SYCL_HOST_KERNEL
#undef REGISTER_SYCL_HOST_REF_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Enter")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_REF_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL

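// An Exit op simply forwards its input to its output. The executor uses it
// to return a tensor from a child frame to its parent frame.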
void ExitOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp);
REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);

#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
#define REGISTER_GPU_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(            \
      Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#undef REGISTER_SYCL_KERNEL

#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

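// A NextIteration op simply forwards its input to its output. The executor
// uses it to advance a tensor from one loop iteration to the next.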
void NextIterationOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
                        NextIterationOp);

#define REGISTER_GPU_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
      NextIterationOp);                                                      \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      NextIterationOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"),    \
      NextIterationOp);                                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      NextIterationOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#undef REGISTER_SYCL_KERNEL

#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

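// A LoopCond op forwards its boolean scalar input to its output. It signals
// to the executor whether the enclosing loop should continue, and fails if
// the step has already been cancelled.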
LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
LoopCondOp::~LoopCondOp() = default;

void LoopCondOp::Compute(OpKernelContext* context) {
  CancellationManager* cm = context->cancellation_manager();
  if (cm != nullptr) {
    bool already_cancelled = cm->IsCancelled();
    OP_REQUIRES(context, !already_cancelled,
                errors::Cancelled("Loop execution was cancelled."));
  }

  context->set_output(0, context->input(0));
}

bool LoopCondOp::IsExpensive() { return false; }

REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);
#endif  // TENSORFLOW_USE_SYCL

// ControlTrigger kernels. ControlTriggerOp does nothing on its own; it only
// serves as a placeholder for control edges in the graph.
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
                        ControlTriggerOp);

REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU),
                        ControlTriggerOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL),
                        ControlTriggerOp);
#endif  // TENSORFLOW_USE_SYCL

// When called, AbortOp aborts the current process. This can be used to abort
// remote parameter servers (PSs) when needed.
class AbortOp : public OpKernel {
 public:
  explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("error_msg", &error_msg_));
    OP_REQUIRES_OK(
        context, context->GetAttr("exit_without_error", &exit_without_error_));
  }

  void Compute(OpKernelContext* context) override {
    if (!exit_without_error_) {
      LOG(FATAL) << "Abort_op intentional failure; " << error_msg_;
    } else {
      LOG(WARNING) << "Exiting the process: " << error_msg_;
      exit(0);
    }
  }

 private:
  string error_msg_;
  bool exit_without_error_;
};

REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);

}  // namespace tensorflow