/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.
#include "tensorflow/core/kernels/segment_reduction_ops_impl.h"

namespace tensorflow {
namespace internal {
// Static routines not in the templated class to reduce code size
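//
// For a sorted segment reduction such as "SegmentSum", segment_ids is a
// vector of non-decreasing ids with one entry per row of the input; e.g.
// input = [[1, 2], [3, 4], [5, 6]] with segment_ids = [0, 0, 1] yields
// [[4, 6], [5, 6]].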
Status ValidateSegmentReduction(OpKernelContext* context, const Tensor& input,
                                const Tensor& segment_ids) {
  if (!TensorShapeUtils::IsVectorOrHigher(input.shape())) {
    return errors::InvalidArgument("input must be at least rank 1");
  }
  if (!TensorShapeUtils::IsVector(segment_ids.shape())) {
    return errors::InvalidArgument("segment_ids should be a vector.");
  }
  const int64_t num_indices = segment_ids.NumElements();
  if (num_indices != input.dim_size(0)) {
    return errors::InvalidArgument(
        "segment_ids should be the same size as dimension 0 of"
        " input.");
  }

  return Status::OK();
}

// Check routines not in the templated class to reduce code size
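//
// For unsorted segment reductions such as "UnsortedSegmentSum", segment_ids
// may have any shape that forms a prefix of data.shape, the ids need not be
// sorted, and the number of output segments is given explicitly by
// num_segments.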
Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
                                        OpKernelContext* context,
                                        const Tensor& data,
                                        const Tensor& segment_ids,
                                        const Tensor& num_segments) {
  if (!TensorShapeUtils::IsScalar(num_segments.shape())) {
    return errors::InvalidArgument(
        "num_segments should be a scalar, not shape ",
        num_segments.shape().DebugString());
  }

  if (!TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape())) {
    return errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
                                   " does not start with segment_ids.shape = ",
                                   segment_ids.shape().DebugString());
  }

  return Status::OK();
}

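// The sparse variants ("SparseSegmentSum", "SparseSegmentMean", ...) gather
// rows of the input via indices and assign each gathered row to an output
// segment via segment_ids, which is why the two vectors must have the same
// length.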
Status ValidateSparseSegmentReduction(OpKernelContext* context,
                                      const Tensor& input,
                                      const Tensor& indices,
                                      const Tensor& segment_ids,
                                      bool has_num_segments) {
  if (has_num_segments) {
    const Tensor& num_segments_t = context->input(3);
    if (!TensorShapeUtils::IsScalar(num_segments_t.shape())) {
      return errors::InvalidArgument(
          "num_segments should be a scalar, not shape ",
          num_segments_t.shape().DebugString());
    }
    int64_t output_rows = internal::SubtleMustCopy(
        num_segments_t.dtype() == DT_INT32 ? num_segments_t.scalar<int32>()()
                                           : num_segments_t.scalar<int64>()());
    if (output_rows < 0) {
      return errors::InvalidArgument("num_segments must be >= 0");
    }
  }

  if (!TensorShapeUtils::IsVector(indices.shape())) {
    return errors::InvalidArgument("indices should be a vector.");
  }

  if (!TensorShapeUtils::IsVector(segment_ids.shape())) {
    return errors::InvalidArgument("segment_ids should be a vector.");
  }

  const int64_t num_indices = indices.NumElements();
  if (num_indices != segment_ids.NumElements()) {
    return errors::InvalidArgument(
        "segment_ids and indices should have same size.");
  }

  if (input.dims() < 1) {
    return errors::InvalidArgument("Shape must be at least rank 1");
  }

  return Status::OK();
}

}  // namespace internal

#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value)                   \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(name)                                                     \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<type>("T")                                 \
          .TypeConstraint<index_type>("Tindices"),                   \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)

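// For reference, a single use of REGISTER_CPU_KERNEL_SEGMENT, e.g.
// REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<float>,
//                             float, int32, 0),
// registers SegmentReductionOp<CPUDevice, float, int32,
// Eigen::internal::SumReducer<float>, 0> as the "SegmentSum" kernel for
// T = float and Tindices = int32 on DEVICE_CPU.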
#define REGISTER_REAL_CPU_KERNELS(type, index_type)                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
                              type, index_type, 0)

#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type)                         \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);

#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                   \
    name, type, index_type, initial_value_functor, reduction_kernel_functor, \
    atomic_reduction_kernel_functor)                                         \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(name)                                                             \
          .Device(DEVICE_GPU)                                                \
          .TypeConstraint<type>("T")                                         \
          .TypeConstraint<index_type>("Tindices"),                           \
      SegmentReductionGPUOp<                                                 \
          type, index_type,                                                  \
          functor::SegmentReductionFunctor<                                  \
              type, index_type, initial_value_functor,                       \
              reduction_kernel_functor, atomic_reduction_kernel_functor> >)

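// Each registration below pairs an initial-value functor with a non-atomic
// and an atomic reduction functor; e.g. "SegmentSum" uses functor::Zero<type>
// as its initial value and functor::NonAtomicSumOpGpu<type> /
// functor::AtomicSumOpGpu<type> as its reduction kernels.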
#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                     \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
      "SegmentSum", type, index_type, functor::Zero<type>,                \
      functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>);   \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
      "SegmentProd", type, index_type, functor::One<type>,                \
      functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
      "SegmentMin", type, index_type, functor::Highest<type>,             \
      functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>);   \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
      "SegmentMax", type, index_type, functor::Lowest<type>,              \
      functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);

#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow