/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.
#include "tensorflow/core/kernels/segment_reduction_ops_impl.h"

namespace tensorflow {
namespace internal {
// Static routines not in the templated class to reduce code size
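//
// Returns an error unless `input` is at least rank 1, `segment_ids` is a
// vector, and the number of segment ids matches the size of the first
// dimension of `input`.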
Status ValidateSegmentReduction(OpKernelContext* context, const Tensor& input,
                                const Tensor& segment_ids) {
  if (!TensorShapeUtils::IsVectorOrHigher(input.shape())) {
    return errors::InvalidArgument("input must be at least rank 1");
  }
  if (!TensorShapeUtils::IsVector(segment_ids.shape())) {
    return errors::InvalidArgument("segment_ids should be a vector.");
  }
  const int64_t num_indices = segment_ids.NumElements();
  if (num_indices != input.dim_size(0)) {
    return errors::InvalidArgument(
        "segment_ids should be the same size as dimension 0 of"
        " input.");
  }

  return Status::OK();
}

// Check routines not in the templated class to reduce code size
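//
// Returns an error unless `num_segments` is a scalar and the shape of
// `segment_ids` is a prefix of the shape of `data`.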
Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
                                        OpKernelContext* context,
                                        const Tensor& data,
                                        const Tensor& segment_ids,
                                        const Tensor& num_segments) {
  if (!TensorShapeUtils::IsScalar(num_segments.shape())) {
    return errors::InvalidArgument(
        "num_segments should be a scalar, not shape ",
        num_segments.shape().DebugString());
  }

  if (!TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape())) {
    return errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
                                   " does not start with segment_ids.shape = ",
                                   segment_ids.shape().DebugString());
  }

  return Status::OK();
}

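// Returns an error unless `indices` and `segment_ids` are equally-sized
// vectors, `input` is at least rank 1, and, when `has_num_segments` is true,
// the `num_segments` input (input 3) is a non-negative scalar.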
Status ValidateSparseSegmentReduction(OpKernelContext* context,
                                      const Tensor& input,
                                      const Tensor& indices,
                                      const Tensor& segment_ids,
                                      bool has_num_segments) {
  if (has_num_segments) {
    const Tensor& num_segments_t = context->input(3);
    if (!TensorShapeUtils::IsScalar(num_segments_t.shape())) {
      return errors::InvalidArgument(
          "num_segments should be a scalar, not shape ",
          num_segments_t.shape().DebugString());
    }
    int64_t output_rows = internal::SubtleMustCopy(
        num_segments_t.dtype() == DT_INT32 ? num_segments_t.scalar<int32>()()
                                           : num_segments_t.scalar<int64>()());
    if (output_rows < 0) {
      return errors::InvalidArgument("num_segments must be >= 0");
    }
  }

  if (!TensorShapeUtils::IsVector(indices.shape())) {
    return errors::InvalidArgument("indices should be a vector.");
  }

  if (!TensorShapeUtils::IsVector(segment_ids.shape())) {
    return errors::InvalidArgument("segment_ids should be a vector.");
  }

  const int64_t num_indices = indices.NumElements();
  if (num_indices != segment_ids.NumElements()) {
    return errors::InvalidArgument(
        "segment_ids and indices should have same size.");
  }

  if (input.dims() < 1) {
    return errors::InvalidArgument("Shape must be at least rank 1");
  }

  return Status::OK();
}

}  // namespace internal

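// Registers a sorted segment reduction kernel on the CPU for one combination
// of op name, Eigen reducer, element type, and index type. For example,
// "SegmentSum" over data = [[1, 2], [3, 4], [5, 6]] with
// segment_ids = [0, 0, 1] produces [[4, 6], [5, 6]]: rows that share a
// segment id are reduced together. The `default_value` argument (0 for most
// reducers, 1 for prod) fills output rows whose segment id never appears in
// `segment_ids`.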
#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)

#define REGISTER_REAL_CPU_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
                              type, index_type, 0)

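// Min and Max are not defined for complex numbers, so the complex
// registrations cover only Sum, Mean, and Prod.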
#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);

#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32)

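// The _ALL macros instantiate these CPU kernels for int32 segment indices
// ("Tindices") only in this file.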
TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
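// Registers a sorted segment reduction kernel on the GPU. Each op supplies an
// initial value functor plus non-atomic and atomic variants of its reduction,
// which parameterize the functor::SegmentReductionFunctor used by
// SegmentReductionGPUOp.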
#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
    name, type, index_type, initial_value_functor, reduction_kernel_functor, \
    atomic_reduction_kernel_functor) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_GPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      SegmentReductionGPUOp< \
          type, index_type, \
          functor::SegmentReductionFunctor< \
              type, index_type, initial_value_functor, \
              reduction_kernel_functor, atomic_reduction_kernel_functor> >)

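// The initial value is the identity element of each reduction: Zero for sum,
// One for prod, Highest<type> for min, and Lowest<type> for max.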
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
      "SegmentSum", type, index_type, functor::Zero<type>, \
      functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>); \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
      "SegmentProd", type, index_type, functor::One<type>, \
      functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
      "SegmentMin", type, index_type, functor::Highest<type>, \
      functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>); \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
      "SegmentMax", type, index_type, functor::Lowest<type>, \
      functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);

#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow