• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 #include "tensorflow/core/kernels/segment_reduction_ops_impl.h"
18 
19 namespace tensorflow {
20 
// Registers a single CPU kernel for the unsorted-segment op named `name`.
//
//   name                  - op name string handed to Name(), e.g.
//                           "UnsortedSegmentSum".
//   type                  - element type; bound to the "T" type constraint.
//   index_type            - index type; bound to the "Tindices" type
//                           constraint.
//   initial_value_functor - forwarded to functor::UnsortedSegmentFunctor; at
//                           the call sites in this file it is Zero, One,
//                           Lowest or Highest.
//   reduction_functor     - forwarded to functor::UnsortedSegmentFunctor; at
//                           the call sites in this file it is SumOp, MaxOp,
//                           MinOp or ProdOp.
//
// The kernel class is UnsortedSegmentReductionOp<type, index_type, F>, where F
// is the CPUDevice specialisation of functor::UnsortedSegmentFunctor.
//
// The expansion is one REGISTER_KERNEL_BUILDER(...) invocation and carries no
// trailing semicolon; the caller supplies it.
#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                           \
    name, type, index_type, initial_value_functor, reduction_functor)  \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_CPU)                                          \
          .TypeConstraint<type>("T")                                   \
          .TypeConstraint<index_type>("Tindices"),                     \
      UnsortedSegmentReductionOp<                                      \
          type, index_type,                                            \
          functor::UnsortedSegmentFunctor<CPUDevice, type, index_type, \
                                          initial_value_functor,       \
                                          reduction_functor> >)
33 
// Registers the four CPU unsorted-segment reductions for one
// (type, index_type) pair, in this order:
//   UnsortedSegmentSum  - functor::Zero<type>    / functor::SumOp<type>
//   UnsortedSegmentMax  - functor::Lowest<type>  / functor::MaxOp<type>
//   UnsortedSegmentMin  - functor::Highest<type> / functor::MinOp<type>
//   UnsortedSegmentProd - functor::One<type>     / functor::ProdOp<type>
//
// Every registration, including the last one, is terminated with ';', so an
// invocation of this macro does not need a semicolon of its own.
#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
                                      functor::Zero<type>,                     \
                                      functor::SumOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
                                      functor::Lowest<type>,                   \
                                      functor::MaxOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
                                      functor::Highest<type>,                  \
                                      functor::MinOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::ProdOp<type>);
47 
// Registers the CPU unsorted-segment reductions used for complex element
// types: UnsortedSegmentSum (Zero / SumOp) and UnsortedSegmentProd
// (One / ProdOp). Max and Min are not registered here - presumably because
// complex values have no ordering; confirm before adding them.
//
// Unlike REGISTER_REAL_CPU_UNSORTED_KERNELS, the final registration has no
// trailing ';', so the invocation must be followed by a semicolon.
#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type)                \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
                                      functor::Zero<type>,                     \
                                      functor::SumOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::ProdOp<type>)
55 
// One-argument adapter around REGISTER_REAL_CPU_UNSORTED_KERNELS that fixes
// the index type to int32, leaving only the element type to be supplied.
// NOTE(review): int32 is the only index type registered in this file; any
// other index type would have to be registered elsewhere - confirm.
#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32)
58 
// One-argument adapter around REGISTER_COMPLEX_CPU_UNSORTED_KERNELS that fixes
// the index type to int32, leaving only the element type to be supplied.
// The expansion has no trailing ';'; the caller adds it.
#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32)
61 
62 TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
63 REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
64 REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
65 
66 #undef REGISTER_REAL_CPU_UNSORTED_KERNELS
67 #undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT
68 #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
69 #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
70 #undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL
71 
72 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Registers a single GPU kernel for the unsorted-segment op named `name`.
//
//   name                     - op name string handed to Name().
//   type                     - element type; bound to the "T" type constraint.
//   index_type               - index type; bound to the "Tindices" type
//                              constraint.
//   initial_value_functor    - forwarded to functor::UnsortedSegmentFunctor.
//   reduction_kernel_functor - forwarded to functor::UnsortedSegmentFunctor; at
//                              the call sites in this file it is one of the
//                              Atomic*OpGpu functors.
//
// The builder additionally declares the "num_segments" argument as HostMemory.
// The kernel class is UnsortedSegmentReductionOp<type, index_type, F>, where F
// is the GPUDevice specialisation of functor::UnsortedSegmentFunctor.
//
// The expansion is one REGISTER_KERNEL_BUILDER(...) invocation and carries no
// trailing semicolon; the caller supplies it.
#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT(                                 \
    name, type, index_type, initial_value_functor, reduction_kernel_functor) \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(name)                                                             \
          .Device(DEVICE_GPU)                                                \
          .HostMemory("num_segments")                                        \
          .TypeConstraint<type>("T")                                         \
          .TypeConstraint<index_type>("Tindices"),                           \
      UnsortedSegmentReductionOp<                                            \
          type, index_type,                                                  \
          functor::UnsortedSegmentFunctor<GPUDevice, type, index_type,       \
                                          initial_value_functor,             \
                                          reduction_kernel_functor> >)
86 
// Registers the GPU UnsortedSegmentMax, UnsortedSegmentMin and
// UnsortedSegmentProd kernels for one (type, index_type) pair, using the
// AtomicMaxOpGpu / AtomicMinOpGpu / AtomicProdOpGpu functors.
//
// Sum is deliberately not part of this macro: it is currently the only op that
// supports all input types, so it has its own registration macro that can be
// applied to a wider set of types.
//
// Every registration, including the last one, is terminated with ';'.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
                                      functor::Lowest<type>,                   \
                                      functor::AtomicMaxOpGpu<type>);          \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
                                      functor::Highest<type>,                  \
                                      functor::AtomicMinOpGpu<type>);          \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::AtomicProdOpGpu<type>);
98 
// Registers the GPU UnsortedSegmentSum kernel for one (type, index_type) pair,
// using functor::Zero<type> and functor::AtomicSumOpGpu<type>.
// The registration is terminated with ';' inside the macro.
#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>,                    \
                                      functor::AtomicSumOpGpu<type>);
103 
// One-argument adapter around REGISTER_REAL_GPU_UNSORTED_KERNELS that fixes
// the index type to int32, leaving only the element type to be supplied.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32)
106 
// One-argument adapter around REGISTER_SUM_GPU_UNSORTED_KERNELS that fixes
// the index type to int32, leaving only the element type to be supplied.
#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32)
109 
110 TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
111 TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
112 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
113 TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
114 // TODO(rocm): support atomicAdd for complex numbers on ROCm
115 #if GOOGLE_CUDA
116 TF_CALL_COMPLEX_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
117 #endif
118 
119 #undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
120 #undef REGISTER_REAL_GPU_UNSORTED_KERNELS
121 #undef REGISTER_SUM_GPU_UNSORTED_KERNELS
122 #undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL
123 #undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL
124 
125 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
126 
127 }  // namespace tensorflow
128