/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/control_flow_ops.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

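// A switch op forwards its data input to output_true if the scalar boolean
// `pred` input is true, and to output_false otherwise. Ref inputs are
// forwarded as refs.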
void SwitchOp::Compute(OpKernelContext* context) {
  const Tensor& outputPorts = context->input(1);
  OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()),
              errors::InvalidArgument("The second input must be a scalar, "
                                      "but it has shape ",
                                      outputPorts.shape().DebugString()));

  bool pred = outputPorts.scalar<bool>()();
  int port = (pred) ? 1 : 0;
  if (context->input_is_ref(0)) {
    context->forward_ref_input_to_ref_output(0, port);
  } else {
    context->set_output(port, context->input(0));
  }
}

#define REGISTER_CPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_CPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_GPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_GPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);
REGISTER_CPU_SWITCH(uint64);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_SWITCH(uint64);
TF_CALL_variant(REGISTER_GPU_SWITCH);

#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH

// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_REF_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_REF_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_SWITCH(type)                        \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)
TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH);

#define REGISTER_SYCL_REF_SWITCH(type)                    \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)
TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);

#undef REGISTER_SYCL_SWITCH
#undef REGISTER_SYCL_REF_SWITCH

#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_SYCL_HOST_KERNEL(bool);
REGISTER_SYCL_HOST_KERNEL(string);
REGISTER_SYCL_HOST_KERNEL(int32);

#define REGISTER_SYCL_HOST_REF_KERNEL(type)               \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_SYCL_HOST_REF_KERNEL(int32);
REGISTER_SYCL_HOST_REF_KERNEL(bool);
REGISTER_SYCL_HOST_REF_KERNEL(string);

#undef REGISTER_SYCL_HOST_KERNEL
#undef REGISTER_SYCL_HOST_REF_KERNEL
#endif  // TENSORFLOW_USE_SYCL

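// A RefSelect op forwards the ref input selected by the scalar `index` input
// to its single ref output.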
class RefSelectOp : public OpKernel {
 public:
  explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& index_tensor = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()),
                errors::InvalidArgument("Index must be a scalar, "
                                        "but it has shape ",
                                        index_tensor.shape().DebugString()));

    int32 index = index_tensor.scalar<int32>()();

    OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_,
                errors::InvalidArgument("Index must be in the range [0, ",
                                        num_ref_inputs_, ") but got ", index));
    context->forward_ref_input_to_ref_output(index + 1, 0);
  }

  bool IsExpensive() override { return false; }

  ~RefSelectOp() override {}

  TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp);

 private:
  int num_ref_inputs_;
};

#define REGISTER_CPU_REF_SELECT(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSelect")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("index")        \
                              .TypeConstraint<type>("T"), \
                          RefSelectOp)
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT);

#undef REGISTER_CPU_REF_SELECT

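// A merge op forwards the first of its inputs that becomes available to
// output 0 and records which input it was in the scalar `value_index`
// output. It is an error for more than one input to be valid.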
MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) {
  const DataType dt = context->input_type(0);
  const int num_in = context->num_inputs();
  OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
                                                  {dt, DT_INT32}));
}

void MergeOp::Compute(OpKernelContext* context) {
  bool input_seen = false;
  for (int i = 0; i < context->num_inputs(); ++i) {
    if (context->has_input(i)) {
      if (input_seen) {
        context->SetStatus(
            errors::Internal("Merge can not have more than one valid input."));
        return;
      }
      input_seen = true;

      if (IsRefType(context->input_dtype(i))) {
        context->forward_ref_input_to_ref_output(i, 0);
      } else {
        context->set_output(0, context->input(i));
      }
      Tensor* value_index = nullptr;
      OP_REQUIRES_OK(
          context, context->allocate_output(1, TensorShape({}), &value_index));
      value_index->scalar<int32>()() = i;
    }
  }
}

246 REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
247 REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);
248 
249 #define REGISTER_GPU_KERNEL(type)                         \
250   REGISTER_KERNEL_BUILDER(Name("Merge")                   \
251                               .Device(DEVICE_GPU)         \
252                               .TypeConstraint<type>("T")  \
253                               .HostMemory("value_index"), \
254                           MergeOp);
255 
256 #define REGISTER_GPU_REF_KERNEL(type)                     \
257   REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
258                               .Device(DEVICE_GPU)         \
259                               .TypeConstraint<type>("T")  \
260                               .HostMemory("value_index"), \
261                           MergeOp);
262 
263 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
264 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
265 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
266 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
267 REGISTER_GPU_KERNEL(bool);
268 REGISTER_GPU_REF_KERNEL(bool);
269 REGISTER_GPU_KERNEL(uint64);
270 TF_CALL_variant(REGISTER_GPU_KERNEL);
271 
272 #undef REGISTER_GPU_KERNEL
273 #undef REGISTER_GPU_REF_KERNEL
274 
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type)                        \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#define REGISTER_SYCL_REF_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);

#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
REGISTER_SYCL_HOST_KERNEL(ResourceHandle);

#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

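// An enter op makes its input available to a child (loop) frame; the kernel
// itself just passes the value through, forwarding ref inputs as refs.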
void EnterOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

355 REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp);
356 REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);
357 
358 #define REGISTER_GPU_KERNEL(type) \
359   REGISTER_KERNEL_BUILDER(        \
360       Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
361 #define REGISTER_GPU_REF_KERNEL(type) \
362   REGISTER_KERNEL_BUILDER(            \
363       Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
364 
365 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
366 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
367 REGISTER_GPU_KERNEL(bool);
368 REGISTER_GPU_REF_KERNEL(bool);
369 TF_CALL_variant(REGISTER_GPU_KERNEL);
370 
371 #undef REGISTER_GPU_KERNEL
372 #undef REGISTER_GPU_REF_KERNEL
373 
374 #ifdef TENSORFLOW_USE_SYCL
375 #define REGISTER_SYCL_KERNEL(type) \
376   REGISTER_KERNEL_BUILDER(         \
377       Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
378 REGISTER_SYCL_KERNEL(bool);
379 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
380 
381 #define REGISTER_SYCL_REF_KERNEL(type) \
382   REGISTER_KERNEL_BUILDER(             \
383       Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
384 REGISTER_SYCL_REF_KERNEL(bool);
385 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
386 
387 #undef REGISTER_SYCL_KERNEL
388 #undef REGISTER_SYCL_REF_KERNEL
389 #define REGISTER_SYCL_HOST_KERNEL(type)                   \
390   REGISTER_KERNEL_BUILDER(Name("Enter")                   \
391                               .Device(DEVICE_SYCL)        \
392                               .HostMemory("data")         \
393                               .HostMemory("output")       \
394                               .TypeConstraint<type>("T"), \
395                           EnterOp)
396 
397 #define REGISTER_SYCL_HOST_REF_KERNEL(type)               \
398   REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
399                               .Device(DEVICE_SYCL)        \
400                               .HostMemory("data")         \
401                               .HostMemory("output")       \
402                               .TypeConstraint<type>("T"), \
403                           EnterOp)
404 
405 REGISTER_SYCL_HOST_KERNEL(int32);
406 REGISTER_SYCL_HOST_REF_KERNEL(int32);
407 REGISTER_SYCL_HOST_KERNEL(string);
408 REGISTER_SYCL_HOST_REF_KERNEL(string);
409 REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
410 
411 #undef REGISTER_SYCL_HOST_KERNEL
412 #undef REGISTER_SYCL_HOST_REF_KERNEL
413 #endif  // TENSORFLOW_USE_SYCL
414 
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Enter")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_REF_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL

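// An exit op makes its input available to the parent (enclosing) frame; the
// kernel itself just passes the value through, forwarding ref inputs as refs.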
void ExitOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

451 REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp);
452 REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
453 
454 #define REGISTER_GPU_KERNEL(type) \
455   REGISTER_KERNEL_BUILDER(        \
456       Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
457 #define REGISTER_GPU_REF_KERNEL(type) \
458   REGISTER_KERNEL_BUILDER(            \
459       Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
460 
461 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
462 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
463 REGISTER_GPU_KERNEL(bool);
464 REGISTER_GPU_REF_KERNEL(bool);
465 TF_CALL_variant(REGISTER_GPU_KERNEL);
466 
467 #undef REGISTER_GPU_KERNEL
468 #undef REGISTER_GPU_REF_KERNEL
469 
470 #ifdef TENSORFLOW_USE_SYCL
471 #define REGISTER_SYCL_KERNEL(type)                                         \
472   REGISTER_KERNEL_BUILDER(                                                 \
473       Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \
474   REGISTER_KERNEL_BUILDER(                                                 \
475       Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp);
476 REGISTER_SYCL_KERNEL(bool);
477 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
478 
479 #undef REGISTER_SYCL_KERNEL
480 #undef REGISTER_SYCL_REF_KERNEL
481 
482 #define REGISTER_SYCL_HOST_KERNEL(type)                   \
483   REGISTER_KERNEL_BUILDER(Name("Exit")                    \
484                               .Device(DEVICE_SYCL)        \
485                               .HostMemory("data")         \
486                               .HostMemory("output")       \
487                               .TypeConstraint<type>("T"), \
488                           ExitOp);                        \
489   REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
490                               .Device(DEVICE_SYCL)        \
491                               .HostMemory("data")         \
492                               .HostMemory("output")       \
493                               .TypeConstraint<type>("T"), \
494                           ExitOp)
495 
496 REGISTER_SYCL_HOST_KERNEL(int32);
497 REGISTER_SYCL_HOST_KERNEL(string);
498 #undef REGISTER_SYCL_HOST_KERNEL
499 #endif  // TENSORFLOW_USE_SYCL
500 
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

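// A next-iteration op makes its input available to the next loop iteration;
// the kernel itself just passes the value through, forwarding ref inputs as
// refs.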
void NextIterationOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

532 REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
533                         NextIterationOp);
534 REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
535                         NextIterationOp);
536 
537 #define REGISTER_GPU_KERNEL(type)                                            \
538   REGISTER_KERNEL_BUILDER(                                                   \
539       Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
540       NextIterationOp);                                                      \
541   REGISTER_KERNEL_BUILDER(                                                   \
542       Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
543       NextIterationOp)
544 
545 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
546 REGISTER_GPU_KERNEL(bool);
547 TF_CALL_variant(REGISTER_GPU_KERNEL);
548 
549 #undef REGISTER_GPU_KERNEL
550 
551 // Special GPU kernels for int32 and string.
552 // TODO(b/25387198): Also enable int32 in device memory. This kernel
553 // registration requires all int32 inputs and outputs to be in host memory.
554 #define REGISTER_GPU_HOST_KERNEL(type)                    \
555   REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
556                               .Device(DEVICE_GPU)         \
557                               .HostMemory("data")         \
558                               .HostMemory("output")       \
559                               .TypeConstraint<type>("T"), \
560                           NextIterationOp);               \
561   REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
562                               .Device(DEVICE_GPU)         \
563                               .HostMemory("data")         \
564                               .HostMemory("output")       \
565                               .TypeConstraint<type>("T"), \
566                           NextIterationOp)
567 
568 REGISTER_GPU_HOST_KERNEL(int32);
569 REGISTER_GPU_HOST_KERNEL(string);
570 REGISTER_GPU_HOST_KERNEL(ResourceHandle);
571 
572 #undef REGISTER_GPU_HOST_KERNEL
573 
574 #ifdef TENSORFLOW_USE_SYCL
575 #define REGISTER_SYCL_KERNEL(type)                                            \
576   REGISTER_KERNEL_BUILDER(                                                    \
577       Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"),    \
578       NextIterationOp);                                                       \
579   REGISTER_KERNEL_BUILDER(                                                    \
580       Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
581       NextIterationOp)
582 REGISTER_SYCL_KERNEL(bool);
583 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
584 
585 #undef REGISTER_SYCL_KERNEL
586 
587 #define REGISTER_SYCL_HOST_KERNEL(type)                   \
588   REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
589                               .Device(DEVICE_SYCL)        \
590                               .HostMemory("data")         \
591                               .HostMemory("output")       \
592                               .TypeConstraint<type>("T"), \
593                           NextIterationOp);               \
594   REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
595                               .Device(DEVICE_SYCL)        \
596                               .HostMemory("data")         \
597                               .HostMemory("output")       \
598                               .TypeConstraint<type>("T"), \
599                           NextIterationOp)
600 
601 REGISTER_SYCL_HOST_KERNEL(int32);
602 REGISTER_SYCL_HOST_KERNEL(string);
603 #undef REGISTER_SYCL_HOST_KERNEL
604 #endif  // TENSORFLOW_USE_SYCL
605 
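// A loop-cond op forwards the boolean loop predicate from its input to its
// output, after checking that the step has not been cancelled.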
LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
LoopCondOp::~LoopCondOp() = default;

void LoopCondOp::Compute(OpKernelContext* context) {
  CancellationManager* cm = context->cancellation_manager();
  if (cm != nullptr) {
    bool already_cancelled = cm->IsCancelled();
    OP_REQUIRES(context, !already_cancelled,
                errors::Cancelled("Loop execution was cancelled."));
  }

  context->set_output(0, context->input(0));
}

bool LoopCondOp::IsExpensive() { return false; }

REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);
#endif  // TENSORFLOW_USE_SYCL

// ControlTrigger kernels
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
                        ControlTriggerOp);

REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU),
                        ControlTriggerOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL),
                        ControlTriggerOp);
#endif  // TENSORFLOW_USE_SYCL

// When called, the abort op terminates the current process. This can be used
// to abort remote parameter servers when needed.
class AbortOp : public OpKernel {
 public:
  explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("error_msg", &error_msg_));
    OP_REQUIRES_OK(
        context, context->GetAttr("exit_without_error", &exit_without_error_));
  }

  void Compute(OpKernelContext* context) override {
    if (!exit_without_error_) {
      LOG(FATAL) << "Abort_op intentional failure; " << error_msg_;
    } else {
      LOG(WARNING) << "Exiting the process: " << error_msg_;
      exit(0);
    }
  }

 private:
  string error_msg_;
  bool exit_without_error_;
};

673 REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);
674 
675 }  // namespace tensorflow
676