• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 #include "tensorflow/core/kernels/segment_reduction_ops_impl.h"
18 
19 namespace tensorflow {
20 
// Fan-out helper: for a given value type and index type, invokes
// REGISTER_CPU_SPARSE_KERNELS once with int32 and once with int64 as the
// segment-id type. REGISTER_CPU_SPARSE_KERNELS is (re)defined and #undef'd
// once per op family later in this file; because macro expansion happens at
// the point of use, whichever definition is live at the TF_CALL_* site is used.
21 #define REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, index_type) \
22   REGISTER_CPU_SPARSE_KERNELS(type, index_type, int32)                         \
23   REGISTER_CPU_SPARSE_KERNELS(type, index_type, int64)
// Fan-out helper: for a given value type, invokes the helper above once with
// int32 and once with int64 as the index type, giving four
// (index type, segment-id type) combinations per value type.
24 #define REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(type)       \
25   REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int32) \
26   REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int64)
27 
// CPU registrations for the sum reduction. For one
// (type, index_type, segment_ids_type) triple this registers two kernels on
// DEVICE_CPU, both constrained on attrs "T", "Tidx" and "Tsegmentids":
//   - "SparseSegmentSum"                -> SparseSegmentReductionSumOp
//   - "SparseSegmentSumWithNumSegments" ->
//         SparseSegmentReductionSumWithNumSegmentsOp
28 #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type)       \
29   REGISTER_KERNEL_BUILDER(                                                    \
30       Name("SparseSegmentSum")                                                \
31           .Device(DEVICE_CPU)                                                 \
32           .TypeConstraint<type>("T")                                          \
33           .TypeConstraint<index_type>("Tidx")                                 \
34           .TypeConstraint<segment_ids_type>("Tsegmentids"),                   \
35       SparseSegmentReductionSumOp<CPUDevice, type, index_type,                \
36                                   segment_ids_type>);                         \
37   REGISTER_KERNEL_BUILDER(                                                    \
38       Name("SparseSegmentSumWithNumSegments")                                 \
39           .Device(DEVICE_CPU)                                                 \
40           .TypeConstraint<type>("T")                                          \
41           .TypeConstraint<index_type>("Tidx")                                 \
42           .TypeConstraint<segment_ids_type>("Tsegmentids"),                   \
43       SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type, index_type, \
44                                                  segment_ids_type>);
// Instantiate the registrations above over the TF_CALL_REAL_NUMBER_TYPES type
// list (defined outside this file; presumably one invocation per real numeric
// type -- confirm against its definition). Note the sum uses a different type
// list than the Mean/SqrtN registrations in this file, which use
// TF_CALL_FLOAT_TYPES.
45 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
46 #undef REGISTER_CPU_SPARSE_KERNELS
47 
// CPU registrations for the mean reduction. For one
// (type, index_type, segment_ids_type) triple this registers two kernels on
// DEVICE_CPU, both constrained on attrs "T", "Tidx" and "Tsegmentids":
//   - "SparseSegmentMean"                -> SparseSegmentReductionMeanOp
//   - "SparseSegmentMeanWithNumSegments" ->
//         SparseSegmentReductionMeanWithNumSegmentsOp
48 #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type)        \
49   REGISTER_KERNEL_BUILDER(                                                     \
50       Name("SparseSegmentMean")                                                \
51           .Device(DEVICE_CPU)                                                  \
52           .TypeConstraint<type>("T")                                           \
53           .TypeConstraint<index_type>("Tidx")                                  \
54           .TypeConstraint<segment_ids_type>("Tsegmentids"),                    \
55       SparseSegmentReductionMeanOp<CPUDevice, type, index_type,                \
56                                    segment_ids_type>);                         \
57   REGISTER_KERNEL_BUILDER(                                                     \
58       Name("SparseSegmentMeanWithNumSegments")                                 \
59           .Device(DEVICE_CPU)                                                  \
60           .TypeConstraint<type>("T")                                           \
61           .TypeConstraint<index_type>("Tidx")                                  \
62           .TypeConstraint<segment_ids_type>("Tsegmentids"),                    \
63       SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type, index_type, \
64                                                   segment_ids_type>);
// Instantiate over the TF_CALL_FLOAT_TYPES type list (defined outside this
// file; presumably floating-point types only -- confirm against its
// definition), crossed with int32/int64 index and segment-id types.
65 TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
66 #undef REGISTER_CPU_SPARSE_KERNELS
67 
// CPU registrations for the sqrt(N)-normalized reduction. For one
// (type, index_type, segment_ids_type) triple this registers two kernels on
// DEVICE_CPU, both constrained on attrs "T", "Tidx" and "Tsegmentids":
//   - "SparseSegmentSqrtN"                -> SparseSegmentReductionSqrtNOp
//   - "SparseSegmentSqrtNWithNumSegments" ->
//         SparseSegmentReductionSqrtNWithNumSegmentsOp
68 #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
69   REGISTER_KERNEL_BUILDER(                                              \
70       Name("SparseSegmentSqrtN")                                        \
71           .Device(DEVICE_CPU)                                           \
72           .TypeConstraint<type>("T")                                    \
73           .TypeConstraint<index_type>("Tidx")                           \
74           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
75       SparseSegmentReductionSqrtNOp<CPUDevice, type, index_type,        \
76                                     segment_ids_type>);                 \
77   REGISTER_KERNEL_BUILDER(                                              \
78       Name("SparseSegmentSqrtNWithNumSegments")                         \
79           .Device(DEVICE_CPU)                                           \
80           .TypeConstraint<type>("T")                                    \
81           .TypeConstraint<index_type>("Tidx")                           \
82           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
83       SparseSegmentReductionSqrtNWithNumSegmentsOp<                     \
84           CPUDevice, type, index_type, segment_ids_type>);
// Instantiate over the TF_CALL_FLOAT_TYPES type list (defined outside this
// file), crossed with int32/int64 index and segment-id types.
85 TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
86 #undef REGISTER_CPU_SPARSE_KERNELS
87 
88 // TODO(benbarsdell): These kernels are disabled on Windows as a workaround for
89 // a CI build error: "formal parameter with requested alignment of 128 won't be
90 // aligned". The root cause is suspected to be an aligned type (AlignedVector)
91 // being passed to a function by value, possibly inside the CUB library
92 // somewhere, but I have not yet been able to reproduce it in isolation outside
93 // of the GitHub CI.
94 #if GOOGLE_CUDA && !defined(PLATFORM_WINDOWS)
95 
// GPU counterpart of the CPU fan-out helpers: for a given value type and index
// type, invokes REGISTER_GPU_SPARSE_KERNELS once with int32 and once with
// int64 as the segment-id type. REGISTER_GPU_SPARSE_KERNELS is (re)defined and
// #undef'd once per op family below; the definition live at the TF_CALL_* site
// is the one that gets expanded.
96 #define REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, index_type) \
97   REGISTER_GPU_SPARSE_KERNELS(type, index_type, int32)                         \
98   REGISTER_GPU_SPARSE_KERNELS(type, index_type, int64)
// For a given value type, invokes the helper above once with int32 and once
// with int64 as the index type, giving four (index type, segment-id type)
// combinations per value type.
99 #define REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(type)       \
100   REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int32) \
101   REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int64)
102 
// GPU registrations for the sum reduction. For one
// (type, index_type, segment_ids_type) triple this registers two kernels on
// DEVICE_GPU, both constrained on attrs "T", "Tidx" and "Tsegmentids":
//   - "SparseSegmentSum"                -> SparseSegmentReductionSumOp
//   - "SparseSegmentSumWithNumSegments" ->
//         SparseSegmentReductionSumWithNumSegmentsOp
// Unlike the CPU registration, the WithNumSegments variant additionally tags
// the "num_segments" argument with HostMemory (presumably so the kernel can
// read its value on the host -- confirm against the kernel implementation).
103 #define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type)       \
104   REGISTER_KERNEL_BUILDER(                                                    \
105       Name("SparseSegmentSum")                                                \
106           .Device(DEVICE_GPU)                                                 \
107           .TypeConstraint<type>("T")                                          \
108           .TypeConstraint<index_type>("Tidx")                                 \
109           .TypeConstraint<segment_ids_type>("Tsegmentids"),                   \
110       SparseSegmentReductionSumOp<GPUDevice, type, index_type,                \
111                                   segment_ids_type>);                         \
112   REGISTER_KERNEL_BUILDER(                                                    \
113       Name("SparseSegmentSumWithNumSegments")                                 \
114           .Device(DEVICE_GPU)                                                 \
115           .HostMemory("num_segments")                                         \
116           .TypeConstraint<type>("T")                                          \
117           .TypeConstraint<index_type>("Tidx")                                 \
118           .TypeConstraint<segment_ids_type>("Tsegmentids"),                   \
119       SparseSegmentReductionSumWithNumSegmentsOp<GPUDevice, type, index_type, \
120                                                  segment_ids_type>);
// Instantiate over the TF_CALL_GPU_NUMBER_TYPES type list (defined outside
// this file), crossed with int32/int64 index and segment-id types.
121 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
122 #undef REGISTER_GPU_SPARSE_KERNELS
123 
// GPU registrations for the mean reduction. For one
// (type, index_type, segment_ids_type) triple this registers two kernels on
// DEVICE_GPU, both constrained on attrs "T", "Tidx" and "Tsegmentids":
//   - "SparseSegmentMean"                -> SparseSegmentReductionMeanOp
//   - "SparseSegmentMeanWithNumSegments" ->
//         SparseSegmentReductionMeanWithNumSegmentsOp
// The WithNumSegments variant additionally tags the "num_segments" argument
// with HostMemory.
124 #define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type)        \
125   REGISTER_KERNEL_BUILDER(                                                     \
126       Name("SparseSegmentMean")                                                \
127           .Device(DEVICE_GPU)                                                  \
128           .TypeConstraint<type>("T")                                           \
129           .TypeConstraint<index_type>("Tidx")                                  \
130           .TypeConstraint<segment_ids_type>("Tsegmentids"),                    \
131       SparseSegmentReductionMeanOp<GPUDevice, type, index_type,                \
132                                    segment_ids_type>);                         \
133   REGISTER_KERNEL_BUILDER(                                                     \
134       Name("SparseSegmentMeanWithNumSegments")                                 \
135           .Device(DEVICE_GPU)                                                  \
136           .HostMemory("num_segments")                                          \
137           .TypeConstraint<type>("T")                                           \
138           .TypeConstraint<index_type>("Tidx")                                  \
139           .TypeConstraint<segment_ids_type>("Tsegmentids"),                    \
140       SparseSegmentReductionMeanWithNumSegmentsOp<GPUDevice, type, index_type, \
141                                                   segment_ids_type>);
// Instantiate over the TF_CALL_GPU_NUMBER_TYPES type list (defined outside
// this file), crossed with int32/int64 index and segment-id types.
142 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
143 #undef REGISTER_GPU_SPARSE_KERNELS
144 
// GPU registrations for the sqrt(N)-normalized reduction. For one
// (type, index_type, segment_ids_type) triple this registers two kernels on
// DEVICE_GPU, both constrained on attrs "T", "Tidx" and "Tsegmentids":
//   - "SparseSegmentSqrtN"                -> SparseSegmentReductionSqrtNOp
//   - "SparseSegmentSqrtNWithNumSegments" ->
//         SparseSegmentReductionSqrtNWithNumSegmentsOp
// The WithNumSegments variant additionally tags the "num_segments" argument
// with HostMemory.
145 #define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
146   REGISTER_KERNEL_BUILDER(                                              \
147       Name("SparseSegmentSqrtN")                                        \
148           .Device(DEVICE_GPU)                                           \
149           .TypeConstraint<type>("T")                                    \
150           .TypeConstraint<index_type>("Tidx")                           \
151           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
152       SparseSegmentReductionSqrtNOp<GPUDevice, type, index_type,        \
153                                     segment_ids_type>);                 \
154   REGISTER_KERNEL_BUILDER(                                              \
155       Name("SparseSegmentSqrtNWithNumSegments")                         \
156           .Device(DEVICE_GPU)                                           \
157           .HostMemory("num_segments")                                   \
158           .TypeConstraint<type>("T")                                    \
159           .TypeConstraint<index_type>("Tidx")                           \
160           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
161       SparseSegmentReductionSqrtNWithNumSegmentsOp<                     \
162           GPUDevice, type, index_type, segment_ids_type>);
// Instantiate over the TF_CALL_GPU_NUMBER_TYPES type list (defined outside
// this file), crossed with int32/int64 index and segment-id types.
163 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
164 #undef REGISTER_GPU_SPARSE_KERNELS
165 
166 #endif  // GOOGLE_CUDA && !defined(PLATFORM_WINDOWS)
167 
// CPU registration for the gradient of the sum reduction. For one
// (type, index_type, segment_ids_type) triple this registers the
// "SparseSegmentSumGrad" kernel on DEVICE_CPU, constrained on attrs "T",
// "Tidx" and "Tsegmentids", implemented by SparseSegmentSumGradOp.
168 #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
169   REGISTER_KERNEL_BUILDER(                                              \
170       Name("SparseSegmentSumGrad")                                      \
171           .Device(DEVICE_CPU)                                           \
172           .TypeConstraint<type>("T")                                    \
173           .TypeConstraint<index_type>("Tidx")                           \
174           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
175       SparseSegmentSumGradOp<CPUDevice, type, index_type, segment_ids_type>);
// Instantiate over the TF_CALL_FLOAT_TYPES type list (defined outside this
// file). Note this differs from the forward SparseSegmentSum CPU kernels in
// this file, which are instantiated over TF_CALL_REAL_NUMBER_TYPES.
176 TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
177 #undef REGISTER_CPU_SPARSE_KERNELS
178 
// CPU registration for the gradient of the mean reduction. For one
// (type, index_type, segment_ids_type) triple this registers the
// "SparseSegmentMeanGrad" kernel on DEVICE_CPU, constrained on attrs "T",
// "Tidx" and "Tsegmentids", implemented by SparseSegmentMeanGradOp.
179 #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
180   REGISTER_KERNEL_BUILDER(                                              \
181       Name("SparseSegmentMeanGrad")                                     \
182           .Device(DEVICE_CPU)                                           \
183           .TypeConstraint<type>("T")                                    \
184           .TypeConstraint<index_type>("Tidx")                           \
185           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
186       SparseSegmentMeanGradOp<CPUDevice, type, index_type, segment_ids_type>);
// Instantiate over the TF_CALL_FLOAT_TYPES type list (defined outside this
// file), crossed with int32/int64 index and segment-id types.
187 TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
188 #undef REGISTER_CPU_SPARSE_KERNELS
189 
// CPU registration for the gradient of the sqrt(N)-normalized reduction. For
// one (type, index_type, segment_ids_type) triple this registers the
// "SparseSegmentSqrtNGrad" kernel on DEVICE_CPU, constrained on attrs "T",
// "Tidx" and "Tsegmentids", implemented by SparseSegmentSqrtNGradOp.
190 #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
191   REGISTER_KERNEL_BUILDER(                                              \
192       Name("SparseSegmentSqrtNGrad")                                    \
193           .Device(DEVICE_CPU)                                           \
194           .TypeConstraint<type>("T")                                    \
195           .TypeConstraint<index_type>("Tidx")                           \
196           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
197       SparseSegmentSqrtNGradOp<CPUDevice, type, index_type,             \
198                                segment_ids_type>);
// Instantiate over the TF_CALL_FLOAT_TYPES type list (defined outside this
// file), crossed with int32/int64 index and segment-id types.
199 TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Last CPU op family in this file: drop the per-family macro here; the CPU
// fan-out helper macros are #undef'd immediately afterwards.
200 #undef REGISTER_CPU_SPARSE_KERNELS
201 
202 #undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE
203 #undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE
204 
205 // TODO(benbarsdell): See comment above.
206 #if GOOGLE_CUDA && !defined(PLATFORM_WINDOWS)
207 
// GPU registration for the gradient of the sum reduction (only compiled when
// the enclosing GOOGLE_CUDA && !defined(PLATFORM_WINDOWS) guard holds). For one
// (type, index_type, segment_ids_type) triple this registers the
// "SparseSegmentSumGrad" kernel on DEVICE_GPU, constrained on attrs "T",
// "Tidx" and "Tsegmentids", implemented by SparseSegmentSumGradOp. The
// "output_dim0" argument is tagged with HostMemory (presumably so the kernel
// can read its value on the host -- confirm against the kernel implementation).
208 #define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
209   REGISTER_KERNEL_BUILDER(                                              \
210       Name("SparseSegmentSumGrad")                                      \
211           .Device(DEVICE_GPU)                                           \
212           .HostMemory("output_dim0")                                    \
213           .TypeConstraint<type>("T")                                    \
214           .TypeConstraint<index_type>("Tidx")                           \
215           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
216       SparseSegmentSumGradOp<GPUDevice, type, index_type, segment_ids_type>);
// Instantiate over the TF_CALL_GPU_NUMBER_TYPES type list (defined outside
// this file), crossed with int32/int64 index and segment-id types.
217 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Drop the definition so the next op family can reuse the same macro name.
218 #undef REGISTER_GPU_SPARSE_KERNELS
219 
// GPU registration for the gradient of the mean reduction. This whole section
// is compiled out by the `#if 0` below, so no "SparseSegmentMeanGrad" GPU
// kernel is registered by this file; per the TODO it is to be re-enabled once
// b/192086735 is fixed. When enabled it would register the kernel on
// DEVICE_GPU with "output_dim0" tagged HostMemory, implemented by
// SparseSegmentMeanGradOp, over the TF_CALL_GPU_NUMBER_TYPES type list.
220 #if 0  // TODO(b/192086735): Enable once bug is fixed.
221 #define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
222   REGISTER_KERNEL_BUILDER(                                              \
223       Name("SparseSegmentMeanGrad")                                     \
224           .Device(DEVICE_GPU)                                           \
225           .HostMemory("output_dim0")                                    \
226           .TypeConstraint<type>("T")                                    \
227           .TypeConstraint<index_type>("Tidx")                           \
228           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
229       SparseSegmentMeanGradOp<GPUDevice, type, index_type, segment_ids_type>);
230 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
231 #undef REGISTER_GPU_SPARSE_KERNELS
232 #endif
233 
// GPU registration for the gradient of the sqrt(N)-normalized reduction (only
// compiled when the enclosing GOOGLE_CUDA && !defined(PLATFORM_WINDOWS) guard
// holds). For one (type, index_type, segment_ids_type) triple this registers
// the "SparseSegmentSqrtNGrad" kernel on DEVICE_GPU, constrained on attrs "T",
// "Tidx" and "Tsegmentids", implemented by SparseSegmentSqrtNGradOp. The
// "output_dim0" argument is tagged with HostMemory.
234 #define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
235   REGISTER_KERNEL_BUILDER(                                              \
236       Name("SparseSegmentSqrtNGrad")                                    \
237           .Device(DEVICE_GPU)                                           \
238           .HostMemory("output_dim0")                                    \
239           .TypeConstraint<type>("T")                                    \
240           .TypeConstraint<index_type>("Tidx")                           \
241           .TypeConstraint<segment_ids_type>("Tsegmentids"),             \
242       SparseSegmentSqrtNGradOp<GPUDevice, type, index_type,             \
243                                segment_ids_type>);
// Instantiate over the TF_CALL_GPU_NUMBER_TYPES type list (defined outside
// this file), crossed with int32/int64 index and segment-id types.
244 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
// Last GPU op family in this file: drop the per-family macro here; the GPU
// fan-out helper macros are #undef'd immediately afterwards.
245 #undef REGISTER_GPU_SPARSE_KERNELS
246 
247 #undef REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE
248 #undef REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE
249 
250 #endif  // GOOGLE_CUDA && !defined(PLATFORM_WINDOWS)
251 
252 }  // namespace tensorflow
253