/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#include "tensorflow/core/util/cuda_solvers.h"
using stream_executor::rocm::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace internal {
Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input,
                                const Tensor& segment_ids);
Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
                                        OpKernelContext* context,
                                        const Tensor& data,
                                        const Tensor& segment_ids,
                                        const Tensor& num_segments);
Status ValidateSparseSegmentReduction(OpKernelContext* context,
                                      const Tensor& input,
                                      const Tensor& indices,
                                      const Tensor& segment_ids,
                                      bool has_num_segments);
}  // namespace internal

// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
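// Illustrative sketch (editorial, not part of the original file): with a sum
// Reducer and default_value == 0, sorted segment reduction behaves like
//   data        = [[1, 2], [3, 4], [5, 6]]
//   segment_ids = [0, 0, 2]            // must be sorted and non-negative
//   output      = [[4, 6], [0, 0], [5, 6]]
// Segment 1 has no rows, so its output row is filled with default_value.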
template <typename Device, class T, class Index, typename Reducer,
          int default_value>
class SegmentReductionOp : public OpKernel {
 public:
  explicit SegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_OK(context, internal::ValidateSegmentReduction(context, input,
                                                               segment_ids));

    const int64_t num_indices = segment_ids.NumElements();
    auto input_flat = input.flat_outer_dims<T>();
    const int64_t num_col = input_flat.dimension(1);

    const auto segment_vec = segment_ids.vec<Index>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const Index output_rows =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    OP_REQUIRES(context, input.dims() >= 1,
                errors::InvalidArgument("Shape must be at least rank 1"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value, so
    // we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
    dims_to_reduce[0] = 0;
#else
    Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
#endif
    Index start = 0, end = 1;

    Index uninitialized_index = 0;  // Index from which the output is not set.
    Index out_index = internal::SubtleMustCopy(segment_vec(start));

    // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    // across threads.
    Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    while (end <= num_indices) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      Index next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here.  Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      // Process segment [start, end)
      const T* in_slice_ptr = &input_flat(start, 0);
      typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                               Eigen::Unaligned>
          OutT;

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(T(default_value));
      }

      T* out_slice_ptr = &output_flat(out_index, 0);
      OutT out_slice(out_slice_ptr, out_slice_shape);
      // We don't use out_slice.device(context->eigen_device<Device>)
      // because these pieces of work are likely to be very small and
      // the context switching overhead dwarfs any benefit we get from
      // using another thread to do this work.
      if (start == end - 1) {
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, out_slice_shape);
        out_slice = in_slice;
      } else {
        Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
                                                           num_col);
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, in_slice_shape);

        out_slice = in_slice.reduce(dims_to_reduce, Reducer());
      }
      if (end >= num_indices) break;
      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

//  SegmentReductionGPUOp is a segment reduction operator implemented for GPU
//  only.
//  TODO: This implementation of SegmentReductionGPUOp is sometimes slower than
//  its unsorted counterpart (mostly when problem size is small).
//  This is due to the following two main reasons and a cost-effective way
//  to resolve these problems is desirable.
//  1. Sorted segment reduction requires a memory transfer from device to host
//     in order to know the size of the output dimension whereas unsorted
//     segment reduction receives the size of the output dimension as an input
//     parameter.
//  2. Sorted segment reduction is essentially a tiled version of unsorted
//     segment reduction and therefore such optimization comes at an inherent
//     cost. However such cost may not be justified when the problem size is
//     small. When to use the tiled version or the untiled version depends on
//     many factors including data alignments, ratio of calculation to memory
//     traffic and obviously, the problem sizes.
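//
// Flow outline (editorial summary of the code below, not original to this
// file): ComputeAsync copies the last segment id from device memory into a
// host ScratchSpace, then asks the EventMgr to run a host callback once the
// copy has completed; that callback allocates an output with
// (last segment id + 1) rows and launches SegmentReductionFunctor.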
template <class T, class Index, class SegmentReductionFunctor>
class SegmentReductionGPUOp : public AsyncOpKernel {
 public:
  explicit SegmentReductionGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    OP_REQUIRES_ASYNC(context, input.dims() >= 1,
                      errors::InvalidArgument("Shape must be at least rank 1"),
                      done);

    const int64_t num_indices = segment_ids.NumElements();
    OP_REQUIRES_ASYNC(
        context, num_indices == input.dim_size(0),
        errors::InvalidArgument(
            "segment_ids should be the same size as dimension 0 of"
            " input."),
        done);

    if (num_indices == 0) {
      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, 0);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);
      done();
      return;
    }

    se::DeviceMemoryBase output_rows_device(
        const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
        (num_indices - 1));
    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                         sizeof(Index))
            .ok(),
        errors::Internal(type_string() +
                         ": failed to copy output_rows from device"),
        done);

    SegmentReductionFunctor functor_;
    auto create_and_check_output = [context, output_rows_host, &input,
                                    &segment_ids, &functor_, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Index output_rows = *output_rows_host.data();
      output_rows++;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      // The determinism check is here, rather than inside the functor (as it is
      // for the unsorted segment reduction ops) because the done callback
      // (required for OP_REQUIRES_ASYNC) is not available inside the functor.
      bool determinism_requirement_met =
          SegmentReductionFunctor::atomic_reduction_is_associative ||
          !OpDeterminismRequired() ||
          DisableSegmentReductionOpDeterminismExceptions();
      OP_REQUIRES_ASYNC(
          context, determinism_requirement_met,
          errors::Unimplemented(
              "Deterministic GPU implementation of sorted segment reduction op"
              " not available."),
          done);

      auto output_flat = output->flat_outer_dims<T>();
      auto data_ptr = input.template flat<T>().data();
      auto segment_flat = segment_ids.flat<Index>();
      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
               segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
               output_flat);

      done();
    };

    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// ____________________________________________________________________________
// Unsorted segment reduction ops.

namespace functor {

// The ReductionFunctor implementation for CPU.
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  typename TTypes<T, 2>::ConstTensor data,
                  typename TTypes<T, 2>::Tensor output) {
    output.setConstant(InitialValueF()());
    if (data.size() == 0) {
      return;
    }
    const int64_t N = segment_ids.dimension(0);
    const int64_t num_segments = output.dimension(0);
    ReductionF reduction;
    for (int64_t i = 0; i < N; ++i) {
      Index j = internal::SubtleMustCopy(segment_ids(i));
      if (j < 0) {
        continue;
      }
      OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
                  errors::InvalidArgument(
                      "segment_ids", SliceDebugString(segment_ids_shape, i),
                      " = ", j, " is out of range [0, ", num_segments, ")"));
      reduction(data.template chip<0>(i), output.template chip<0>(j));
    }
  }
};
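
// Illustrative sketch (editorial, not part of the original file): with
// ReductionF = SumOp<float> and InitialValueF()() == 0,
//   data         = [[1, 2], [3, 4], [5, 6]]
//   segment_ids  = [0, 2, 0]           // need not be sorted
//   num_segments = 3
// yields output = [[6, 8], [0, 0], [3, 4]]. Negative segment ids are skipped,
// and empty segments keep the initial value.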

template <typename T>
using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;

template <typename T>
using constMatrixChip =
    Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;

// reduction functors
template <typename T>
struct SumOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output += data;
  }
};

template <typename T>
struct MaxOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMax(output);
  }
};

template <typename T>
struct MinOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMin(output);
  }
};

template <typename T>
struct ProdOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output *= data;
  }
};
}  // namespace functor

// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
// is the device specific implementation of the reduction. These device
// specific implementations are templated themselves with the corresponding
// initial value functors and reduction functors.
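//
// Shape sketch (editorial, not original to this file): segment_ids may cover
// several leading dimensions of data. For data of shape [4, 3, 2] and
// segment_ids of shape [4, 3] with num_segments == 5, the output has shape
// [5, 2]: one row per segment, followed by the remaining inner dimensions of
// data.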
template <typename T, typename Index, typename DeviceReductionFunctor>
class UnsortedSegmentReductionOp : public OpKernel {
 public:
  explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& segment_ids = context->input(1);
    const Tensor& num_segments = context->input(2);
    OP_REQUIRES_OK(context,
                   internal::ValidateUnsortedSegmentReduction(
                       this, context, data, segment_ids, num_segments));
    const auto segment_flat = segment_ids.flat<Index>();
    const int64_t output_rows = internal::SubtleMustCopy(static_cast<int64>(
        num_segments.dtype() == DT_INT32 ? num_segments.scalar<int32>()()
                                         : num_segments.scalar<int64>()()));
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("Input num_segments == ", output_rows,
                                        " must not be negative."));
    TensorShape output_shape;
    output_shape.AddDim(output_rows);
    for (int i = segment_ids.dims(); i < data.dims(); i++) {
      output_shape.AddDim(data.dim_size(i));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_flat = output->flat_outer_dims<T>();
    auto data_flat = data.flat_inner_outer_dims<T, 2>(segment_ids.dims() - 1);
    reduction_functor_(context, segment_ids.shape(), segment_flat, data_flat,
                       output_flat);
  }

 protected:
  DeviceReductionFunctor reduction_functor_;
};

// ____________________________________________________________________________
// Sparse segment reduction ops.

// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
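//
// Illustrative sketch (editorial, not original to this file): for the sum
// variant, `indices` selects rows of the data before they are reduced, e.g.
//   data        = [[1, 2], [3, 4], [5, 6], [7, 8]]
//   indices     = [0, 2, 3]
//   segment_ids = [0, 1, 1]            // sorted
//   output      = [[1, 2], [12, 14]]
// The mean and sqrtn variants divide each segment by its row count or by
// sqrt(row count), respectively.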
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase : public OpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : OpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_OK(
        context, internal::ValidateSparseSegmentReduction(
                     context, input, indices, segment_ids, has_num_segments_));

    Index output_rows = -1;
    if (has_num_segments_) {
      const Tensor& num_segments = context->input(3);
      // Note that there is a Tnumsegments parameter on the op, but it is not
      // plumbed through to here and so always takes its default value of int32.
      output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
    }
    const int64_t num_indices = indices.NumElements();

    auto input_flat = input.flat_outer_dims<T>();
    const int64_t num_col = input_flat.dimension(1);
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const SegmentId last_segment_id_plus_one =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    if (has_num_segments_) {
      OP_REQUIRES(
          context, output_rows >= last_segment_id_plus_one,
          errors::InvalidArgument("segment ids must be < num_segments"));
    } else {
      output_rows = last_segment_id_plus_one;
    }
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value, so
    // we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) {
      if (output_rows > 0) {
        output->flat_outer_dims<T>().setConstant(default_value_);
      }
      return;
    }
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    Tensor temp;
    if (input.dtype() == DT_BFLOAT16 || input.dtype() == DT_HALF) {
      temp = tensorflow::Tensor(DT_FLOAT, output_shape);
    }
    auto temp_flat = temp.flat_outer_dims<float>();

    int64_t start = 0, end = 1;
    // Index from which the output is not initialized.
    SegmentId uninitialized_index = 0;
    SegmentId out_index = internal::SubtleMustCopy(segment_vec(start));

    while (true) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      SegmentId next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here.  Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(default_value_);
      }

      auto out = output_flat.template chip<0>(out_index);
      auto temp = temp_flat.template chip<0>(out_index);
      const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
                                              end - start, out, temp);
      OP_REQUIRES(context, bad_offset < 0,
                  errors::InvalidArgument(
                      "Bad: indices[", start + bad_offset,
                      "] == ", indices_vec(start + bad_offset),
                      " out of range [0, ", input_flat.dimension(0), ")"));

      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
      if (end > num_indices) break;
    }

    // Fill the gap at the end with the default value.
    if (uninitialized_index < output_rows) {
      Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
          output_rows - uninitialized_index, num_col);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
          gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
      gap_slice.setConstant(default_value_);
    }
  }

 private:
  template <typename Tin>
  using EnableIfBfloat16OrHalf =
      typename std::enable_if<std::is_same<Tin, bfloat16>::value ||
                                  std::is_same<Tin, Eigen::half>::value,
                              int>::type;
  template <typename Tin>
  using EnableIfNotBfloat16OrHalf =
      typename std::enable_if<!std::is_same<Tin, bfloat16>::value &&
                                  !std::is_same<Tin, Eigen::half>::value,
                              int>::type;

  template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
  EIGEN_ALWAYS_INLINE auto fetch_val(
      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
    return input_flat.template chip<0>(index);
  }

  template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
  EIGEN_ALWAYS_INLINE auto fetch_val(
      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
    return input_flat.template chip<0>(index).template cast<float>();
  }

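  // Editorial note (not original to this file): for small segments (fewer
  // than 10 rows) the mean/sqrtn scaling is folded into the unrolled sum in
  // ReduceImpl via the factor returned below; for larger segments the factor
  // is 1 and the division by num (or sqrt(num)) happens once at the end of
  // ReduceImpl.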
  template <typename Tout>
  EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64_t num) {
    Tout m(1);
    if (is_mean_ && (num < 10)) {
      m = Tout(num);
    }
    if (is_sqrtn_ && (num < 10)) {
      m = Tout(sqrt(num));
    }
    return Tout(1) / m;
  }

  template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
  int64 Reduce(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
    return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
                                        out, get_scaling_factor<Tin>(num));
  }

  template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
  int64 Reduce(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
    int64_t res =
        ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
                                       temp, get_scaling_factor<float>(num));
    out = temp.template cast<Tin>();
    return res;
  }

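  // Editorial note (not original to this file): ReduceImpl sums the rows of
  // input_flat selected by indices_vec[start .. start + num). The switch
  // handles the first num % 8 rows (remainders 0 and 1 are handled as initial
  // batches of 8 and 9 rows), and the trailing loop accumulates the remaining
  // rows eight at a time. A non-negative return value is the offset of the
  // first out-of-range index; -1 means success.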
  template <typename Tin, typename Tindex, typename Tout>
  int64 ReduceImpl(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
      const Tout scaling_factor) {
#define INDEX(n, i)                               \
  const auto index##n = indices_vec(start + (i)); \
  if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);

#define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)

    if (num == 1) {
      INDEX(0, 0);
      out = L(0);
    } else {
      int64_t r = num & 7;
      switch (r) {
        case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
          out = (L(0) + L(1)) * scaling_factor;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          out = (L(0) + L(1) + L(2)) * scaling_factor;
          break;
        }
        case 4: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
          break;
        }
        case 5: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
          break;
        }
        case 6: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
          break;
        }
        case 7: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          out =
              (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor;
          break;
        }
        case 0: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
                scaling_factor;
          r = 8;
          break;
        }
        case 1: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) *
                scaling_factor;
          r = 9;
          break;
        }
      }
      for (; r < num; r += 8) {
        INDEX(0, r);
        INDEX(1, r + 1);
        INDEX(2, r + 2);
        INDEX(3, r + 3);
        INDEX(4, r + 4);
        INDEX(5, r + 5);
        INDEX(6, r + 6);
        INDEX(7, r + 7);
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
        out = out / static_cast<Tout>(num);
      }
      if (is_sqrtn_ && num >= 10) {
        out = out / static_cast<Tout>(sqrt(num));
      }
    }

    return -1;
#undef L
#undef INDEX
  }

  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Specialization for GPU. Must be Async because it may need to wait for a
// device to host memcpy before allocating the output.
template <class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase<GPUDevice, T, Index, SegmentId>
    : public AsyncOpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : AsyncOpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_OK_ASYNC(
        context,
        internal::ValidateSparseSegmentReduction(
            context, input, indices, segment_ids, has_num_segments_),
        done);

    ScratchSpace<SegmentId> last_segment_id_host(context, 1, /*on_host=*/true);

    auto create_and_check_output = [this, context, input, indices, segment_ids,
                                    last_segment_id_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      SegmentId last_segment_id = *last_segment_id_host.data();
      SegmentId output_rows = last_segment_id + 1;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      auto input_flat = input.flat_outer_dims<T>();
      const auto indices_vec = indices.vec<Index>();
      const auto segment_ids_vec = segment_ids.vec<SegmentId>();
      auto output_flat = output->flat_outer_dims<T>();

      functor::SparseSegmentReductionFunctor<T, Index, SegmentId> functor;
      OP_REQUIRES_OK_ASYNC(
          context,
          functor(context, is_mean_, is_sqrtn_, default_value_, input_flat,
                  indices_vec, segment_ids_vec, output_flat),
          done);
      done();
    };

    if (has_num_segments_) {
      // No need to do any device to host memcpy, just compute synchronously.
      const Tensor& num_segments_t = context->input(3);
      SegmentId num_segments =
          internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32
                                       ? num_segments_t.scalar<int32>()()
                                       : num_segments_t.scalar<int64>()());
      *last_segment_id_host.mutable_data() = num_segments - 1;
      create_and_check_output();
    } else {
      const int64_t num_indices = indices.NumElements();
      // Need to copy last element of segment_ids from device to host, and then
      // asynchronously allocate the output and finish the computation.
      se::DeviceMemoryBase last_segment_id_device(
          const_cast<Tensor&>(segment_ids).template flat<SegmentId>().data() +
          (num_indices - 1));
      auto stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(
          context,
          stream
              ->ThenMemcpy(last_segment_id_host.mutable_data(),
                           last_segment_id_device, sizeof(SegmentId))
              .ok(),
          errors::Internal(type_string() +
                           ": failed to copy last_segment_id from device"),
          done);
      context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
          stream, create_and_check_output);
    }
  }

 private:
  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionMeanWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSumWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

namespace functor {

template <typename T, typename Index, typename SegmentId>
struct SparseSegmentGradFunctor<CPUDevice, T, Index, SegmentId> {
  void operator()(OpKernelContext* context,
                  SparseSegmentReductionOperation operation,
                  typename TTypes<T>::ConstMatrix input_flat,
                  typename TTypes<Index>::ConstVec indices_vec,
                  typename TTypes<SegmentId>::ConstVec segment_vec,
                  typename TTypes<T>::Matrix output_flat) {
    const int64_t N = indices_vec.size();
    const SegmentId M = output_flat.dimension(0);

    // Note that similar to SparseSegmentMean, we assume that segment_vec is
    // already sorted and has non-negative values.
    const SegmentId num_segments = input_flat.dimension(0);
    const SegmentId last_segment_id_plus_one =
        internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
    OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
                errors::InvalidArgument("Invalid number of segments"));

    // Compute scaling factors for input.
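    // Editorial note (not original to this file): for kMean each segment's
    // contribution is divided by its occurrence count, for kSqrtN by
    // sqrt(count); kSum needs no scaling, so the vector below stays empty in
    // that case.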
    std::vector<double> scaling(
        (operation == SparseSegmentReductionOperation::kSum ? 0 : num_segments),
        0.0);
    if (operation != SparseSegmentReductionOperation::kSum) {
      for (int64_t i = 0; i < N; ++i) {
        const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
        OP_REQUIRES(
            context, FastBoundsCheck(idx, num_segments),
            errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                    num_segments, ")."));
        scaling[idx] += 1;
      }
      for (size_t i = 0; i < scaling.size(); ++i) {
        switch (operation) {
          case SparseSegmentReductionOperation::kSum: {
            OP_REQUIRES(
                context, false,
                errors::Internal(
                    "Should not happen: sum inside SparseSegmentReductionOp "
                    "scaling generation."));
          }
          case SparseSegmentReductionOperation::kMean: {
            scaling[i] = 1.0 / std::max(scaling[i], 1.0);
            break;
          }
          case SparseSegmentReductionOperation::kSqrtN: {
            scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
            break;
          }
            // No default to get compiler warnings for missing cases.
        }
      }
    }

    output_flat.setZero();
    std::vector<bool> is_modified(M, false);

    for (int64_t i = 0; i < N; ++i) {
      const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
      OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
                  errors::InvalidArgument("Index ", output_idx,
                                          " out of range [0, ", M, ")."));

      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));

      const T scale = (operation == SparseSegmentReductionOperation::kSum
                           ? static_cast<T>(1)
                           : static_cast<T>(scaling[idx]));
      if (is_modified[output_idx]) {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx) * scale;
        }
      } else {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx) * scale;
        }
      }
      is_modified[output_idx] = true;
    }
  }
};

}  // namespace functor

// Implements the common logic for the gradients of SparseSegmentReduction
// kernels.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
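//
// Input sketch (editorial, not original to this file): `input` is the incoming
// gradient with one row per segment of the forward op, `indices` and
// `segment_ids` are the forward op's inputs, and `output_dim0` is dimension 0
// of the forward op's data tensor, which becomes dimension 0 of the gradient
// produced here.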
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentGradOpBase : public OpKernel {
 public:
  explicit SparseSegmentGradOpBase(OpKernelConstruction* context,
                                   SparseSegmentReductionOperation operation)
      : OpKernel(context), operation_(operation) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);
    const Tensor& output_dim0 = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()),
                errors::InvalidArgument("output_dim0 should be a scalar."));

    const int64_t N = indices.NumElements();
    OP_REQUIRES(context, N == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));
    const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar<int32>()());

    auto input_flat = input.flat_outer_dims<T>();
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, M);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (M == 0 || N == 0) return;

    auto output_flat = output->flat_outer_dims<T>();
    functor::SparseSegmentGradFunctor<Device, T, Index, SegmentId>()(
        context, operation_, input_flat, indices_vec, segment_vec, output_flat);
  }

 private:
  const SparseSegmentReductionOperation operation_;
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentSumGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentSumGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kSum) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentMeanGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kMean) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentSqrtNGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kSqrtN) {}
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_