1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 #if GOOGLE_CUDA
20 #define EIGEN_USE_GPU
21 #endif  // GOOGLE_CUDA
22 
23 #include "third_party/eigen3/Eigen/Core"
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 
26 #include "tensorflow/core/kernels/segment_reduction_ops.h"
27 #include <vector>
28 
29 #include "tensorflow/core/framework/bounds_check.h"
30 #include "tensorflow/core/framework/numeric_op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/register_types.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_types.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/lib/core/status.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/util/util.h"
39 
40 #if GOOGLE_CUDA
41 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
42 #include "tensorflow/core/kernels/cuda_solvers.h"
43 #include "tensorflow/core/platform/cuda.h"
44 
45 using stream_executor::cuda::ScopedActivateExecutorContext;
46 #endif  // GOOGLE_CUDA
47 
48 namespace tensorflow {
49 
50 typedef Eigen::ThreadPoolDevice CPUDevice;
51 typedef Eigen::GpuDevice GPUDevice;
52 
53 // Static routines not in the templated class to reduce code size
54 static void SegmentReductionValidationHelper(OpKernelContext* context,
55                                              const Tensor& input,
56                                              const Tensor& segment_ids) {
57   OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
58               errors::InvalidArgument("segment_ids should be a vector."));
59   const int64 num_indices = segment_ids.NumElements();
60   OP_REQUIRES(context, num_indices == input.dim_size(0),
61               errors::InvalidArgument(
62                   "segment_ids should be the same size as dimension 0 of"
63                   " input."));
64 }
65 
66 static bool SegmentReductionDoValidation(OpKernelContext* c,
67                                          const Tensor& input,
68                                          const Tensor& segment_ids) {
69   SegmentReductionValidationHelper(c, input, segment_ids);
70   return c->status().ok();
71 }
72 
73 // This operator handles reducing segments along the first dimension.
74 // See core/ops/math_ops.cc for more details.
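// A small worked example (illustrative, assuming the SegmentSum reducer) of
// what this kernel computes for data of shape [4, 2]:
//
//   data        = [[1, 2], [3, 4], [5, 6], [7, 8]]
//   segment_ids = [0, 0, 2, 2]                 // must be sorted; ids may skip
//   output      = [[4, 6], [0, 0], [12, 14]]
//
// Output row 1 corresponds to a segment id that never occurs, so it is filled
// with the reducer's default value (0 for sum/mean/min/max, 1 for prod), which
// is what the gap-filling logic inside Compute() handles.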
75 template <typename Device, class T, class Index, typename Reducer,
76           int default_value>
77 class SegmentReductionOp : public OpKernel {
78  public:
79   explicit SegmentReductionOp(OpKernelConstruction* context)
80       : OpKernel(context) {}
81 
82   void Compute(OpKernelContext* context) override {
83     const Tensor& input = context->input(0);
84     const Tensor& segment_ids = context->input(1);
85 
86     if (!SegmentReductionDoValidation(context, input, segment_ids)) {
87       return;
88     }
89 
90     const int64 num_indices = segment_ids.NumElements();
91     auto input_flat = input.flat_outer_dims<T>();
92     const int64 num_col = input_flat.dimension(1);
93 
94     const auto segment_vec = segment_ids.vec<Index>();
95     // Note that the current implementation assumes that segment_vec values are
96     // sorted.
97     const Index output_rows =
98         num_indices > 0
99             ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
100             : 0;
101     OP_REQUIRES(context, output_rows >= 0,
102                 errors::InvalidArgument("segment ids must be >= 0"));
103 
104     TensorShape output_shape = input.shape();
105     output_shape.set_dim(0, output_rows);
106 
107     // Note that we do not initialize the output buffer with a default value, so
108     // we need to explicitly set missing indices to the default value.
109     Tensor* output = nullptr;
110     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
111     if (num_indices == 0) return;
112     OP_REQUIRES(context, output_rows > 0,
113                 errors::InvalidArgument("segment ids must be >= 0"));
114     auto output_flat = output->flat_outer_dims<T>();
115 
116 #if !defined(EIGEN_HAS_INDEX_LIST)
117     Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
118     dims_to_reduce[0] = 0;
119 #else
120     Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
121 #endif
122     Index start = 0, end = 1;
123 
124     Index uninitialized_index = 0;  // Index from which the output is not set.
125     Index out_index = internal::SubtleMustCopy(segment_vec(start));
126 
127     // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
128     // across threads.
129     Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
130     while (end <= num_indices) {
131       // We initialize next_index to 0 to avoid "warning: 'next_index' may be
132       // used uninitialized in this function" in the Mac build (since the
133       // compiler isn't smart enough to realize the code is safe).
134       Index next_index = 0;
135       if (end < num_indices) {
136         next_index = internal::SubtleMustCopy(segment_vec(end));
137         if (out_index == next_index) {
138           ++end;
139           continue;
140         }
141         // We have a new segment here.  Verify that the segment ids are growing.
142         OP_REQUIRES(context, out_index < next_index,
143                     errors::InvalidArgument("segment ids are not increasing"));
144       }
145 
146       // Process segment [start, end)
147       const T* in_slice_ptr = &input_flat(start, 0);
148       typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
149                                Eigen::Unaligned>
150           OutT;
151 
152       OP_REQUIRES(
153           context, FastBoundsCheck(out_index, output_rows),
154           errors::InvalidArgument(
155               "Segment id ", out_index, " out of range [0, ", output_rows,
156               "), possibly because 'segment_ids' input is not sorted."));
157 
158       // If there is a gap between two indices, we need to set that gap to the
159       // default value.
160       if (out_index > uninitialized_index) {
161         Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
162             out_index - uninitialized_index, num_col);
163         Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
164             gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
165         gap_slice.setConstant(T(default_value));
166       }
167 
168       T* out_slice_ptr = &output_flat(out_index, 0);
169       OutT out_slice(out_slice_ptr, out_slice_shape);
170       // We don't use out_slice.device(context->eigen_device<Device>)
171       // because these pieces of work are likely to be very small and
172       // the context switching overhead dwarfs any benefit we get from
173       // using another thread to do this work.
174       if (start == end - 1) {
175         typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
176                                  Eigen::Unaligned>
177             InT;
178         InT in_slice(in_slice_ptr, out_slice_shape);
179         out_slice = in_slice;
180       } else {
181         Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
182                                                            num_col);
183         typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
184                                  Eigen::Unaligned>
185             InT;
186         InT in_slice(in_slice_ptr, in_slice_shape);
187 
188         out_slice = in_slice.reduce(dims_to_reduce, Reducer());
189       }
190       if (end >= num_indices) break;
191       start = end;
192       ++end;
193       uninitialized_index = out_index + 1;
194       out_index = next_index;
195     }
196   }
197 };
198 
199 #ifdef GOOGLE_CUDA
200 //  SegmentSumGPUOp is a segment sum operator implemented for GPU only.
201 //  TODO: This implementation of SegmentSumGPUOp is sometimes slower than
202 //  its unsorted counterpart (mostly when problem size is small).
203 //  This is due to the following two main reasons, and a cost-effective way
204 //  to resolve them is desirable.
205 //  1. Sorted segment sum requires a memory transfer from device to host in
206 //     order to know the size of the output dimension whereas unsorted segment
207 //     sum receives the size of the output dimension as an input parameter.
208 //  2. Sorted segment sum is essentially a tiled version of unsorted segment
209 //     sum and therefore such optimization comes at an inherent cost. However
210 //     such cost may not be justified when the problem size is small. When to
211 //     use the tiled version or the untiled version depends on many factors
212 //     including data alignments, ratio of calculation to memory traffic and
213 //     obviously, the problem sizes.
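// To make the trade-off above concrete (illustrative numbers): with
// segment_ids = [0, 0, 1, 2], this sorted kernel must first copy the last id
// (2) back to the host to learn that the output has 2 + 1 = 3 rows, whereas
// UnsortedSegmentSum is handed num_segments = 3 directly and can launch right
// away. For small inputs that extra device-to-host round trip, plus the
// host-side callback scheduled through the EventMgr below, can dominate the
// runtime.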
214 template <class T, class Index>
215 class SegmentSumGPUOp : public AsyncOpKernel {
216  public:
217   explicit SegmentSumGPUOp(OpKernelConstruction* context)
218       : AsyncOpKernel(context) {}
219 
220   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
221     const Tensor& input = context->input(0);
222     const Tensor& segment_ids = context->input(1);
223 
224     OP_REQUIRES_ASYNC(
225         context, TensorShapeUtils::IsVector(segment_ids.shape()),
226         errors::InvalidArgument("segment_ids should be a vector."), done);
227 
228     const int64 num_indices = segment_ids.NumElements();
229     OP_REQUIRES_ASYNC(
230         context, num_indices == input.dim_size(0),
231         errors::InvalidArgument(
232             "segment_ids should be the same size as dimension 0 of"
233             " input."),
234         done);
235 
236     if (num_indices == 0) {
237       TensorShape output_shape = input.shape();
238       output_shape.set_dim(0, 0);
239 
240       Tensor* output = nullptr;
241       OP_REQUIRES_OK_ASYNC(
242           context, context->allocate_output(0, output_shape, &output), done);
243       done();
244       return;
245     }
246 
247     se::DeviceMemoryBase output_rows_device(
248         const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
249         (num_indices - 1));
250     ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);
251 
252     auto stream = context->op_device_context()->stream();
253     OP_REQUIRES_ASYNC(
254         context,
255         stream
256             ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
257                          sizeof(Index))
258             .ok(),
259         errors::Internal(
260             "SegmentSumGPUOp: failed to copy output_rows from device"),
261         done);
262 
263     functor::SegmentSumFunctor<T, Index> functor_;
264     auto create_and_check_output = [context, output_rows_host, &input,
265                                     &segment_ids, &functor_, done]() {
266       // Ensure that within the callback, the proper GPU settings are
267       // configured.
268       auto stream = context->op_device_context()->stream();
269       ScopedActivateExecutorContext scoped_activation{stream->parent()};
270 
271       Index output_rows = *output_rows_host.data();
272       output_rows++;
273       OP_REQUIRES_ASYNC(context, output_rows > 0,
274                         errors::InvalidArgument("segment ids must be >= 0"),
275                         done);
276 
277       TensorShape output_shape = input.shape();
278       output_shape.set_dim(0, output_rows);
279 
280       Tensor* output = nullptr;
281       OP_REQUIRES_OK_ASYNC(
282           context, context->allocate_output(0, output_shape, &output), done);
283 
284       auto output_flat = output->flat_outer_dims<T>();
285       auto data_ptr = input.template flat<T>().data();
286       auto segment_flat = segment_ids.flat<Index>();
287       functor_(context, context->eigen_device<GPUDevice>(), output_rows,
288                segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
289                output_flat);
290 
291       done();
292     };
293 
294     context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
295         stream, create_and_check_output);
296   }
297 };
298 #endif  // GOOGLE_CUDA
299 
300 #define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
301                                     default_value)                   \
302   REGISTER_KERNEL_BUILDER(                                           \
303       Name(name)                                                     \
304           .Device(DEVICE_CPU)                                        \
305           .TypeConstraint<type>("T")                                 \
306           .TypeConstraint<index_type>("Tindices"),                   \
307       SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)
308 
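// As a sketch of what one expansion of the macro above produces, the
// instantiation ("SegmentSum", Eigen::internal::SumReducer<float>, float,
// int32, 0) registers roughly:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("SegmentSum")
//           .Device(DEVICE_CPU)
//           .TypeConstraint<float>("T")
//           .TypeConstraint<int32>("Tindices"),
//       SegmentReductionOp<CPUDevice, float, int32,
//                          Eigen::internal::SumReducer<float>, 0>);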
309 #define REGISTER_REAL_CPU_KERNELS(type, index_type)                            \
310   REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
311                               type, index_type, 0);                            \
312   REGISTER_CPU_KERNEL_SEGMENT(                                                 \
313       "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
314   REGISTER_CPU_KERNEL_SEGMENT(                                                 \
315       "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
316   REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
317                               type, index_type, 0);                            \
318   REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
319                               type, index_type, 0)
320 
321 #define REGISTER_COMPLEX_CPU_KERNELS(type, index_type)                         \
322   REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
323                               type, index_type, 0);                            \
324   REGISTER_CPU_KERNEL_SEGMENT(                                                 \
325       "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
326   REGISTER_CPU_KERNEL_SEGMENT(                                                 \
327       "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);
328 
329 #define REGISTER_REAL_CPU_KERNELS_ALL(type) \
330   REGISTER_REAL_CPU_KERNELS(type, int32);   \
331   REGISTER_REAL_CPU_KERNELS(type, int64)
332 
333 #define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
334   REGISTER_COMPLEX_CPU_KERNELS(type, int32);   \
335   REGISTER_COMPLEX_CPU_KERNELS(type, int64)
336 
337 TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
338 REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
339 REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
340 #undef REGISTER_CPU_KERNEL_SEGMENT
341 #undef REGISTER_REAL_CPU_KERNELS
342 #undef REGISTER_COMPLEX_CPU_KERNELS
343 #undef REGISTER_REAL_CPU_KERNELS_ALL
344 #undef REGISTER_COMPLEX_CPU_KERNELS_ALL
345 
346 #if GOOGLE_CUDA
347 #define REGISTER_GPU_SORTED_KERNELS(type, index_type)                  \
348   REGISTER_KERNEL_BUILDER(Name("SegmentSum")                           \
349                               .Device(DEVICE_GPU)                      \
350                               .TypeConstraint<type>("T")               \
351                               .TypeConstraint<index_type>("Tindices"), \
352                           SegmentSumGPUOp<type, index_type>)
353 
354 #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
355   REGISTER_GPU_SORTED_KERNELS(type, int32);   \
356   REGISTER_GPU_SORTED_KERNELS(type, int64);
357 
358 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
359 #undef REGISTER_GPU_SORTED_KERNELS
360 #undef REGISTER_GPU_SORTED_KERNELS_ALL
361 #endif  // GOOGLE_CUDA
362 
363 // ____________________________________________________________________________
364 // Unsorted segment reduction ops.
365 
366 namespace functor {
367 
368 // The ReductionFunctor implementation for CPU.
369 template <typename T, typename Index, typename InitialValueF,
370           typename ReductionF>
371 struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
372   void operator()(OpKernelContext* ctx, const Index num_segments,
373                   const TensorShape& segment_ids_shape,
374                   typename TTypes<Index>::ConstFlat segment_ids,
375                   const Index data_size, const T* data,
376                   typename TTypes<T, 2>::Tensor output) {
377     output.setConstant(InitialValueF()());
378     if (data_size == 0) {
379       return;
380     }
381     const int64 N = segment_ids.dimension(0);
382     ReductionF reduction;
383     auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
384     for (int64 i = 0; i < N; ++i) {
385       Index j = internal::SubtleMustCopy(segment_ids(i));
386       if (j < 0) {
387         continue;
388       }
389       OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
390                   errors::InvalidArgument(
391                       "segment_ids", SliceDebugString(segment_ids_shape, i),
392                       " = ", j, " is out of range [0, ", num_segments, ")"));
393       reduction(data_flat.template chip<0>(i), output.template chip<0>(j));
394     }
395   }
396 };
397 
398 template <typename T>
399 using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;
400 
401 template <typename T>
402 using constMatrixChip =
403     Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;
404 
405 // reduction functors
406 template <typename T>
407 struct SumOp {
408   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
409     output += data;
410   }
411 };
412 
413 template <typename T>
414 struct MaxOp {
415   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
416     output = data.cwiseMax(output);
417   }
418 };
419 
420 template <typename T>
421 struct MinOp {
422   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
423     output = data.cwiseMin(output);
424   }
425 };
426 
427 template <typename T>
428 struct ProdOp {
429   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
430     output *= data;
431   }
432 };
433 }  // namespace functor
434 
435 // Static check routines not in the templated class to reduce code size
436 static void UnsortedSegmentReductionValidation(OpKernel* op_kernel,
437                                                OpKernelContext* context,
438                                                const Tensor& data,
439                                                const Tensor& segment_ids,
440                                                const Tensor& num_segments) {
441   OP_REQUIRES(
442       context, op_kernel->IsLegacyScalar(num_segments.shape()),
443       errors::InvalidArgument("num_segments should be a scalar, not shape ",
444                               num_segments.shape().DebugString()));
445   OP_REQUIRES(
446       context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()),
447       errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
448                               " does not start with segment_ids.shape = ",
449                               segment_ids.shape().DebugString()));
450 }
451 
452 static bool UnsortedSegmentReductionDoValidation(OpKernel* op_kernel,
453                                                  OpKernelContext* context,
454                                                  const Tensor& data,
455                                                  const Tensor& segment_ids,
456                                                  const Tensor& num_segments) {
457   UnsortedSegmentReductionValidation(op_kernel, context, data, segment_ids,
458                                      num_segments);
459   return context->status().ok();
460 }
461 
462 // The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
463 // is the device-specific implementation of the reduction. These device-
464 // specific implementations are themselves templated on the corresponding
465 // initial-value functors and reduction functors.
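// Unlike the sorted kernels above, segment_ids here may appear in any order
// and num_segments is an explicit input. A small worked example (illustrative,
// assuming UnsortedSegmentSum) for data of shape [4, 2]:
//
//   data         = [[1, 2], [3, 4], [5, 6], [7, 8]]
//   segment_ids  = [2, 0, 2, -1]        // unordered; negative ids are dropped
//   num_segments = 3
//   output       = [[3, 4], [0, 0], [6, 8]]
//
// Segment 1 receives no rows, so it keeps the initial value supplied by the
// InitialValueF functor (e.g. functor::Zero<T> for sums).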
466 template <typename T, typename Index, typename DeviceReductionFunctor>
467 class UnsortedSegmentReductionOp : public OpKernel {
468  public:
469   explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
470       : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}
471 
472   void Compute(OpKernelContext* context) override {
473     const Tensor& data = context->input(0);
474     const Tensor& segment_ids = context->input(1);
475     const Tensor& num_segments = context->input(2);
476     if (!UnsortedSegmentReductionDoValidation(this, context, data, segment_ids,
477                                               num_segments)) {
478       return;
479     }
480     const auto segment_flat = segment_ids.flat<Index>();
481     const Index output_rows =
482         internal::SubtleMustCopy(num_segments.scalar<int32>()());
483     OP_REQUIRES(context, output_rows >= 0,
484                 errors::InvalidArgument("Input num_segments == ", output_rows,
485                                         " must not be negative."));
486     TensorShape output_shape;
487     output_shape.AddDim(output_rows);
488     for (int i = segment_ids.dims(); i < data.dims(); i++) {
489       output_shape.AddDim(data.dim_size(i));
490     }
491     Tensor* output = nullptr;
492     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
493     auto output_flat = output->flat_outer_dims<T>();
494     auto data_ptr = data.template flat<T>().data();
495     reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat,
496                        data.NumElements(), data_ptr, output_flat);
497   }
498 
499  protected:
500   DeviceReductionFunctor reduction_functor_;
501 };
502 
503 #define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                           \
504     name, type, index_type, initial_value_functor, reduction_functor)  \
505   REGISTER_KERNEL_BUILDER(                                             \
506       Name(name)                                                       \
507           .Device(DEVICE_CPU)                                          \
508           .TypeConstraint<type>("T")                                   \
509           .TypeConstraint<index_type>("Tindices"),                     \
510       UnsortedSegmentReductionOp<                                      \
511           type, index_type,                                            \
512           functor::UnsortedSegmentFunctor<CPUDevice, type, index_type, \
513                                           initial_value_functor,       \
514                                           reduction_functor> >)
515 
516 #define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type)                   \
517   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
518                                       functor::Zero<type>,                     \
519                                       functor::SumOp<type>);                   \
520   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
521                                       functor::Lowest<type>,                   \
522                                       functor::MaxOp<type>);                   \
523   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
524                                       functor::Highest<type>,                  \
525                                       functor::MinOp<type>);                   \
526   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
527                                       functor::One<type>,                      \
528                                       functor::ProdOp<type>);
529 
530 #define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type)                \
531   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
532                                       functor::Zero<type>,                     \
533                                       functor::SumOp<type>);                   \
534   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
535                                       functor::One<type>,                      \
536                                       functor::ProdOp<type>)
537 
538 #define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
539   REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32);   \
540   REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)
541 
542 #define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
543   REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32);   \
544   REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)
545 
546 TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
547 REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
548 REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
549 
550 #undef REGISTER_REAL_CPU_UNSORTED_KERNELS
551 #undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT
552 #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
553 #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
554 #undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL
555 
556 #if GOOGLE_CUDA
557 #define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT(                                 \
558     name, type, index_type, initial_value_functor, reduction_kernel_functor) \
559   REGISTER_KERNEL_BUILDER(                                                   \
560       Name(name)                                                             \
561           .Device(DEVICE_GPU)                                                \
562           .HostMemory("num_segments")                                        \
563           .TypeConstraint<type>("T")                                         \
564           .TypeConstraint<index_type>("Tindices"),                           \
565       UnsortedSegmentReductionOp<                                            \
566           type, index_type,                                                  \
567           functor::UnsortedSegmentFunctor<GPUDevice, type, index_type,       \
568                                           initial_value_functor,             \
569                                           reduction_kernel_functor> >)
570 
571 // Sum is currently the only op that supports all input types.
572 #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
573   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
574                                       functor::Lowest<type>,                   \
575                                       functor::MaxOpGpu<type>);                \
576   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
577                                       functor::Highest<type>,                  \
578                                       functor::MinOpGpu<type>);                \
579   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
580                                       functor::One<type>,                      \
581                                       functor::ProdOpGpu<type>);
582 
583 #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
584   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
585                                       functor::Zero<type>,                    \
586                                       functor::SumOpGpu<type>);
587 
588 #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
589   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32);   \
590   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64);
591 
592 #define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \
593   REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32);   \
594   REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int64);
595 
596 
597 TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
598 TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
599 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
600 TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
601 TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
602 TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
603 
604 #undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
605 #undef REGISTER_REAL_GPU_UNSORTED_KERNELS
606 #undef REGISTER_SUM_GPU_UNSORTED_KERNELS
607 #undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL
608 #undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL
609 
610 #endif  // GOOGLE_CUDA
611 
612 // ____________________________________________________________________________
613 // Sparse segment reduction ops.
614 
615 // Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
616 // by two dense tensors, one containing the data, and the other containing
617 // indices into the data.
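// A small worked example (illustrative, assuming SparseSegmentSum):
//
//   data        = [[1, 2], [3, 4], [5, 6]]
//   indices     = [0, 2, 2]              // rows of `data` to gather
//   segment_ids = [0, 0, 1]              // sorted, one id per index
//   output      = [[6, 8], [5, 6]]
//
// i.e. the result is equivalent to gathering data[indices] and then applying
// the corresponding sorted segment reduction. The mean and sqrt(N) variants
// additionally divide each segment by N or sqrt(N), where N is the number of
// indices that fall into that segment.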
618 template <typename Device, class T>
619 class SparseSegmentReductionOpBase : public OpKernel {
620  public:
621   explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
622                                         bool is_mean, bool is_sqrtn,
623                                         bool has_num_segments, T default_value)
624       : OpKernel(context),
625         is_mean_(is_mean),
626         is_sqrtn_(is_sqrtn),
627         has_num_segments_(has_num_segments),
628         default_value_(default_value) {}
629 
630   void Compute(OpKernelContext* context) override {
631     const Tensor& input = context->input(0);
632     const Tensor& indices = context->input(1);
633     const Tensor& segment_ids = context->input(2);
634 
635     Index output_rows = -1;
636     if (has_num_segments_) {
637       const Tensor& num_segments = context->input(3);
638 
639       OP_REQUIRES(
640           context, num_segments.shape().dims() == 0,
641           errors::InvalidArgument("num_segments should be a scalar, not shape ",
642                                   num_segments.shape().DebugString()));
643       output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
644       OP_REQUIRES(context, output_rows >= 0,
645                   errors::InvalidArgument("segment ids must be >= 0"));
646     }
647 
648     OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
649                 errors::InvalidArgument("indices should be a vector."));
650     OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
651                 errors::InvalidArgument("segment_ids should be a vector."));
652 
653     const int64 num_indices = indices.NumElements();
654     OP_REQUIRES(context, num_indices == segment_ids.NumElements(),
655                 errors::InvalidArgument(
656                     "segment_ids and indices should have same size."));
657 
658     auto input_flat = input.flat_outer_dims<T>();
659     const int64 num_col = input_flat.dimension(1);
660     const auto indices_vec = indices.vec<Index>();
661     typedef int32 OutputRow;
662     const auto segment_vec = segment_ids.vec<OutputRow>();
663     // Note that the current implementation assumes that segment_vec values are
664     // sorted.
665     const OutputRow last_segment_id_plus_one =
666         num_indices > 0
667             ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
668             : 0;
669     if (has_num_segments_) {
670       OP_REQUIRES(
671           context, output_rows >= last_segment_id_plus_one,
672           errors::InvalidArgument("segment ids must be < num_segments"));
673     } else {
674       output_rows = last_segment_id_plus_one;
675     }
676     OP_REQUIRES(context, output_rows >= 0,
677                 errors::InvalidArgument("segment ids must be >= 0"));
678 
679     TensorShape output_shape = input.shape();
680     output_shape.set_dim(0, output_rows);
681 
682     // Note that we do not initialize the output buffer with a default value, so
683     // we need to explicitly set missing indices to the default value.
684     Tensor* output = nullptr;
685     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
686     if (num_indices == 0) {
687       if (output_rows > 0) {
688         output->flat_outer_dims<T>().setConstant(default_value_);
689       }
690       return;
691     }
692     OP_REQUIRES(context, output_rows > 0,
693                 errors::InvalidArgument("segment ids must be >= 0"));
694     auto output_flat = output->flat_outer_dims<T>();
695 
696     int64 start = 0, end = 1;
697     // Index from which the output is not initialized.
698     OutputRow uninitialized_index = 0;
699     OutputRow out_index = internal::SubtleMustCopy(segment_vec(start));
700 
701     while (true) {
702       // We initialize next_index to 0 to avoid "warning: 'next_index' may be
703       // used uninitialized in this function" in the Mac build (since the
704       // compiler isn't smart enough to realize the code is safe).
705       OutputRow next_index = 0;
706       if (end < num_indices) {
707         next_index = internal::SubtleMustCopy(segment_vec(end));
708         if (out_index == next_index) {
709           ++end;
710           continue;
711         }
712         // We have a new segment here.  Verify that the segment ids are growing.
713         OP_REQUIRES(context, out_index < next_index,
714                     errors::InvalidArgument("segment ids are not increasing"));
715       }
716 
717       OP_REQUIRES(
718           context, FastBoundsCheck(out_index, output_rows),
719           errors::InvalidArgument(
720               "Segment id ", out_index, " out of range [0, ", output_rows,
721               "), possibly because 'segment_ids' input is not sorted."));
722 
723       // If there is a gap between two indices, we need to set that gap to the
724       // default value.
725       if (out_index > uninitialized_index) {
726         Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
727             out_index - uninitialized_index, num_col);
728         Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
729             gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
730         gap_slice.setConstant(default_value_);
731       }
732 
733       auto out = output_flat.template chip<0>(out_index);
734       const int bad_offset =
735           Reduce(input_flat, indices_vec, start, end - start, out);
736       OP_REQUIRES(context, bad_offset < 0,
737                   errors::InvalidArgument(
738                       "Bad: indices[", start + bad_offset,
739                       "] == ", indices_vec(start + bad_offset),
740                       " out of range [0, ", input_flat.dimension(0), ")"));
741 
742       start = end;
743       ++end;
744       uninitialized_index = out_index + 1;
745       out_index = next_index;
746       if (end > num_indices) break;
747     }
748 
749     // Fill the gap at the end with the default value.
750     if (uninitialized_index < output_rows) {
751       Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
752           output_rows - uninitialized_index, num_col);
753       Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
754           gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
755       gap_slice.setConstant(default_value_);
756     }
757   }
758 
759  private:
760   typedef int32 Index;
761 
762   int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
763                const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
764                int64 num,
765                Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
766 #define INDEX(n, i)                               \
767   const auto index##n = indices_vec(start + (i)); \
768   if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
769 
770 #define L(n) input_flat.template chip<0>(index##n)
771 
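    // The switch/loop below is a manual 8-way unroll: the switch consumes the
    // first `num % 8` gathered rows (with the 0 and 1 remainders folded into
    // blocks of 8 and 9, so the number of rows left for the loop is always a
    // multiple of eight), and the loop then accumulates the remaining rows
    // eight at a time. For mean / sqrt(N) with fewer than 10 rows the division
    // is fused into the first block via `m`; otherwise it is applied once at
    // the end.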
772     if (num == 1) {
773       INDEX(0, 0);
774       out = L(0);
775     } else {
776       int64 r = num % 8;
777       T m(1);
778       if (is_mean_ && (num < 10)) {
779         m = T(num);
780       }
781       if (is_sqrtn_ && (num < 10)) {
782         m = T(sqrt(num));
783       }
784       switch (r) {
785         case 2: {
786           INDEX(0, 0);
787           INDEX(1, 1);
788           out = (L(0) + L(1)) / m;
789           break;
790         }
791         case 3: {
792           INDEX(0, 0);
793           INDEX(1, 1);
794           INDEX(2, 2);
795           out = (L(0) + L(1) + L(2)) / m;
796           break;
797         }
798         case 4: {
799           INDEX(0, 0);
800           INDEX(1, 1);
801           INDEX(2, 2);
802           INDEX(3, 3);
803           out = (L(0) + L(1) + L(2) + L(3)) / m;
804           break;
805         }
806         case 5: {
807           INDEX(0, 0);
808           INDEX(1, 1);
809           INDEX(2, 2);
810           INDEX(3, 3);
811           INDEX(4, 4);
812           out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
813           break;
814         }
815         case 6: {
816           INDEX(0, 0);
817           INDEX(1, 1);
818           INDEX(2, 2);
819           INDEX(3, 3);
820           INDEX(4, 4);
821           INDEX(5, 5);
822           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
823           break;
824         }
825         case 7: {
826           INDEX(0, 0);
827           INDEX(1, 1);
828           INDEX(2, 2);
829           INDEX(3, 3);
830           INDEX(4, 4);
831           INDEX(5, 5);
832           INDEX(6, 6);
833           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
834           break;
835         }
836         case 0: {
837           INDEX(0, 0);
838           INDEX(1, 1);
839           INDEX(2, 2);
840           INDEX(3, 3);
841           INDEX(4, 4);
842           INDEX(5, 5);
843           INDEX(6, 6);
844           INDEX(7, 7);
845           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
846           r = 8;
847           break;
848         }
849         case 1: {
850           INDEX(0, 0);
851           INDEX(1, 1);
852           INDEX(2, 2);
853           INDEX(3, 3);
854           INDEX(4, 4);
855           INDEX(5, 5);
856           INDEX(6, 6);
857           INDEX(7, 7);
858           INDEX(8, 8);
859           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
860                 m;
861           r = 9;
862           break;
863         }
864       }
865       for (; r < num; r += 8) {
866         INDEX(0, r);
867         INDEX(1, r + 1);
868         INDEX(2, r + 2);
869         INDEX(3, r + 3);
870         INDEX(4, r + 4);
871         INDEX(5, r + 5);
872         INDEX(6, r + 6);
873         INDEX(7, r + 7);
874         out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
875       }
876       if (is_mean_ && num >= 10) {
877         out = out / static_cast<T>(num);
878       }
879       if (is_sqrtn_ && num >= 10) {
880         out = out / static_cast<T>(sqrt(num));
881       }
882     }
883 
884     return -1;
885 #undef L
886 #undef INDEX
887   }
888 
889   const bool is_mean_;
890   const bool is_sqrtn_;
891   const bool has_num_segments_;
892   const T default_value_;
893 };
894 
895 template <typename Device, class T>
896 class SparseSegmentReductionMeanOp
897     : public SparseSegmentReductionOpBase<Device, T> {
898  public:
899   explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
900       : SparseSegmentReductionOpBase<Device, T>(
901             context, true /*is_mean*/, false /*is_sqrtn*/,
902             false /* has_num_segments */, T(0) /* default_value */) {}
903 };
904 
905 template <typename Device, class T>
906 class SparseSegmentReductionMeanWithNumSegmentsOp
907     : public SparseSegmentReductionOpBase<Device, T> {
908  public:
909   explicit SparseSegmentReductionMeanWithNumSegmentsOp(
910       OpKernelConstruction* context)
911       : SparseSegmentReductionOpBase<Device, T>(
912             context, true /*is_mean*/, false /*is_sqrtn*/,
913             true /* has_num_segments */, T(0) /* default_value */) {}
914 };
915 
916 template <typename Device, class T>
917 class SparseSegmentReductionSqrtNOp
918     : public SparseSegmentReductionOpBase<Device, T> {
919  public:
920   explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
921       : SparseSegmentReductionOpBase<Device, T>(
922             context, false /*is_mean*/, true /*is_sqrtn*/,
923             false /* has_num_segments */, T(0) /* default_value */) {}
924 };
925 
926 template <typename Device, class T>
927 class SparseSegmentReductionSqrtNWithNumSegmentsOp
928     : public SparseSegmentReductionOpBase<Device, T> {
929  public:
930   explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
931       OpKernelConstruction* context)
932       : SparseSegmentReductionOpBase<Device, T>(
933             context, false /*is_mean*/, true /*is_sqrtn*/,
934             true /* has_num_segments */, T(0) /* default_value */) {}
935 };
936 
937 template <typename Device, class T>
938 class SparseSegmentReductionSumOp
939     : public SparseSegmentReductionOpBase<Device, T> {
940  public:
941   explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
942       : SparseSegmentReductionOpBase<Device, T>(
943             context, false /*is_mean*/, false /*is_sqrtn*/,
944             false /* has_num_segments */, T(0) /* default_value */) {}
945 };
946 
947 template <typename Device, class T>
948 class SparseSegmentReductionSumWithNumSegmentsOp
949     : public SparseSegmentReductionOpBase<Device, T> {
950  public:
951   explicit SparseSegmentReductionSumWithNumSegmentsOp(
952       OpKernelConstruction* context)
953       : SparseSegmentReductionOpBase<Device, T>(
954             context, false /*is_mean*/, false /*is_sqrtn*/,
955             true /* has_num_segments */, T(0) /* default_value */) {}
956 };
957 
958 #define REGISTER_CPU_SPARSE_KERNELS(type)                                \
959   REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum")                       \
960                               .Device(DEVICE_CPU)                        \
961                               .TypeConstraint<type>("T")                 \
962                               .TypeConstraint<int32>("Tidx"),            \
963                           SparseSegmentReductionSumOp<CPUDevice, type>); \
964   REGISTER_KERNEL_BUILDER(                                               \
965       Name("SparseSegmentSumWithNumSegments")                            \
966           .Device(DEVICE_CPU)                                            \
967           .TypeConstraint<type>("T")                                     \
968           .TypeConstraint<int32>("Tidx"),                                \
969       SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
970 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
971 #undef REGISTER_CPU_SPARSE_KERNELS
972 
973 #define REGISTER_CPU_SPARSE_KERNELS(type)                                 \
974   REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean")                       \
975                               .Device(DEVICE_CPU)                         \
976                               .TypeConstraint<type>("T")                  \
977                               .TypeConstraint<int32>("Tidx"),             \
978                           SparseSegmentReductionMeanOp<CPUDevice, type>); \
979   REGISTER_KERNEL_BUILDER(                                                \
980       Name("SparseSegmentMeanWithNumSegments")                            \
981           .Device(DEVICE_CPU)                                             \
982           .TypeConstraint<type>("T")                                      \
983           .TypeConstraint<int32>("Tidx"),                                 \
984       SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
985 REGISTER_CPU_SPARSE_KERNELS(float);
986 REGISTER_CPU_SPARSE_KERNELS(double);
987 #undef REGISTER_CPU_SPARSE_KERNELS
988 
989 #define REGISTER_CPU_SPARSE_KERNELS(type)                                  \
990   REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN")                       \
991                               .Device(DEVICE_CPU)                          \
992                               .TypeConstraint<type>("T")                   \
993                               .TypeConstraint<int32>("Tidx"),              \
994                           SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
995   REGISTER_KERNEL_BUILDER(                                                 \
996       Name("SparseSegmentSqrtNWithNumSegments")                            \
997           .Device(DEVICE_CPU)                                              \
998           .TypeConstraint<type>("T")                                       \
999           .TypeConstraint<int32>("Tidx"),                                  \
1000       SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
1001 REGISTER_CPU_SPARSE_KERNELS(float);
1002 REGISTER_CPU_SPARSE_KERNELS(double);
1003 #undef REGISTER_CPU_SPARSE_KERNELS
1004 
1005 template <class T>
1006 class SparseSegmentGradOpBase : public OpKernel {
1007  public:
1008   explicit SparseSegmentGradOpBase(OpKernelConstruction* context, bool is_sqrtn)
1009       : OpKernel(context), is_sqrtn_(is_sqrtn) {}
1010 
1011   void Compute(OpKernelContext* context) override {
1012     const Tensor& input = context->input(0);
1013     const Tensor& indices = context->input(1);
1014     const Tensor& segment_ids = context->input(2);
1015     const Tensor& output_dim0 = context->input(3);
1016 
1017     OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
1018                 errors::InvalidArgument("indices should be a vector."));
1019     OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
1020                 errors::InvalidArgument("segment_ids should be a vector."));
1021     OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
1022                 errors::InvalidArgument("output_dim0 should be a scalar."));
1023 
1024     const int64 N = indices.NumElements();
1025     OP_REQUIRES(context, N == segment_ids.NumElements(),
1026                 errors::InvalidArgument(
1027                     "segment_ids and indices should have same size."));
1028     typedef int32 SegmentId;
1029     const SegmentId M =
1030         internal::SubtleMustCopy(output_dim0.scalar<SegmentId>()());
1031 
1032     auto input_flat = input.flat_outer_dims<T>();
1033     typedef int32 Index;
1034     const auto indices_vec = indices.vec<Index>();
1035     const auto segment_vec = segment_ids.vec<SegmentId>();
1036 
1037     TensorShape output_shape = input.shape();
1038     output_shape.set_dim(0, M);
1039     Tensor* output = nullptr;
1040     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
1041     if (M == 0 || N == 0) return;
1042 
1043     // Note that similar to SparseSegmentMean, we assume that segment_vec is
1044     // already sorted and has non-negative values.
1045     const SegmentId num_segments = input.dim_size(0);
1046     const SegmentId last_segment_id_plus_one =
1047         internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
1048     OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
1049                 errors::InvalidArgument("Invalid number of segments"));
1050 
1051     // Compute scaling factors for input.
1052     std::vector<double> scaling(num_segments, 0.0);
1053     for (int64 i = 0; i < N; ++i) {
1054       const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
1055       OP_REQUIRES(
1056           context, FastBoundsCheck(idx, num_segments),
1057           errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
1058                                   num_segments, ")."));
1059       scaling[idx] += 1;
1060     }
1061     for (size_t i = 0; i < scaling.size(); ++i) {
1062       if (is_sqrtn_) {
1063         scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
1064       } else {
1065         scaling[i] = 1.0 / std::max(scaling[i], 1.0);
1066       }
1067     }
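    // Worked example of the scaling above (illustrative, mean-gradient case,
    // i.e. is_sqrtn_ == false): with segment_vec = [0, 0, 1], segment 0 occurs
    // twice and segment 1 once, so scaling = [0.5, 1.0]; each gathered gradient
    // row is multiplied by the reciprocal of its segment's size before being
    // scattered into the output below (the sqrt(N) variant uses 1/sqrt(count)
    // instead).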
1068 
1069     auto output_flat = output->flat_outer_dims<T>();
1070     output_flat.setZero();
1071     std::vector<bool> is_modified(M, false);
1072 
1073     for (int64 i = 0; i < N; ++i) {
1074       const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
1075       OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
1076                   errors::InvalidArgument("Index ", output_idx,
1077                                           " out of range [0, ", M, ")."));
1078 
1079       const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
1080       OP_REQUIRES(
1081           context, FastBoundsCheck(idx, num_segments),
1082           errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
1083                                   num_segments, ")."));
1084 
1085       const T scale = static_cast<T>(scaling[idx]);
1086       if (is_modified[output_idx]) {
1087         if (scale == 1.0) {
1088           output_flat.template chip<0>(output_idx) +=
1089               input_flat.template chip<0>(idx);
1090         } else {
1091           output_flat.template chip<0>(output_idx) +=
1092               input_flat.template chip<0>(idx) * scale;
1093         }
1094       } else {
1095         if (scale == 1.0) {
1096           output_flat.template chip<0>(output_idx) =
1097               input_flat.template chip<0>(idx);
1098         } else {
1099           output_flat.template chip<0>(output_idx) =
1100               input_flat.template chip<0>(idx) * scale;
1101         }
1102       }
1103       is_modified[output_idx] = true;
1104     }
1105   }
1106 
1107  private:
1108   const bool is_sqrtn_;
1109 };
1110 
1111 template <class T>
1112 class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
1113  public:
1114   explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
1115       : SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
1116 };
1117 
1118 template <class T>
1119 class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
1120  public:
1121   explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
1122       : SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
1123 };
1124 
1125 #define REGISTER_CPU_SPARSE_KERNELS(type)                     \
1126   REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad")       \
1127                               .Device(DEVICE_CPU)             \
1128                               .TypeConstraint<type>("T")      \
1129                               .TypeConstraint<int32>("Tidx"), \
1130                           SparseSegmentMeanGradOp<type>);
1131 REGISTER_CPU_SPARSE_KERNELS(float);
1132 REGISTER_CPU_SPARSE_KERNELS(double);
1133 #undef REGISTER_CPU_SPARSE_KERNELS
1134 
1135 #define REGISTER_CPU_SPARSE_KERNELS(type)                     \
1136   REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad")      \
1137                               .Device(DEVICE_CPU)             \
1138                               .TypeConstraint<type>("T")      \
1139                               .TypeConstraint<int32>("Tidx"), \
1140                           SparseSegmentSqrtNGradOp<type>);
1141 REGISTER_CPU_SPARSE_KERNELS(float);
1142 REGISTER_CPU_SPARSE_KERNELS(double);
1143 #undef REGISTER_CPU_SPARSE_KERNELS
1144 }  // namespace tensorflow
1145