/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Static routines not in the templated class to reduce code size
static void SegmentReductionValidationHelper(OpKernelContext* context,
                                             const Tensor& input,
                                             const Tensor& segment_ids) {
  OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
              errors::InvalidArgument("segment_ids should be a vector."));
  const int64 num_indices = segment_ids.NumElements();
  OP_REQUIRES(context, num_indices == input.dim_size(0),
              errors::InvalidArgument(
                  "segment_ids should be the same size as dimension 0 of"
                  " input."));
}

static bool SegmentReductionDoValidation(OpKernelContext* c,
                                         const Tensor& input,
                                         const Tensor& segment_ids) {
  SegmentReductionValidationHelper(c, input, segment_ids);
  return c->status().ok();
}

// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
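// For example, with a sum reducer:
//   data        = [[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]]
//   segment_ids = [0, 0, 1]
//   output      = [[6, 8, 10, 12], [4, 3, 2, 1]]
// Rows of 'data' that share a segment id are reduced into a single output
// row; segment ids are expected to be sorted and non-negative.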
template <typename Device, class T, class Index, typename Reducer,
          int default_value>
class SegmentReductionOp : public OpKernel {
 public:
  explicit SegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    if (!SegmentReductionDoValidation(context, input, segment_ids)) {
      return;
    }

    const int64 num_indices = segment_ids.NumElements();
    auto input_flat = input.flat_outer_dims<T>();
    const int64 num_col = input_flat.dimension(1);

    const auto segment_vec = segment_ids.vec<Index>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const Index output_rows =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value, so
    // we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
    dims_to_reduce[0] = 0;
#else
    Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
#endif
    Index start = 0, end = 1;

    Index uninitialized_index = 0;  // Index from which the output is not set.
    Index out_index = internal::SubtleMustCopy(segment_vec(start));

    // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    // across threads.
    Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    while (end <= num_indices) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      Index next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      // Process segment [start, end)
      const T* in_slice_ptr = &input_flat(start, 0);
      typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                               Eigen::Unaligned>
          OutT;

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(T(default_value));
      }

      T* out_slice_ptr = &output_flat(out_index, 0);
      OutT out_slice(out_slice_ptr, out_slice_shape);
      // We don't use out_slice.device(context->eigen_device<Device>)
      // because these pieces of work are likely to be very small and
      // the context switching overhead dwarfs any benefit we get from
      // using another thread to do this work.
      if (start == end - 1) {
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, out_slice_shape);
        out_slice = in_slice;
      } else {
        Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
                                                           num_col);
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, in_slice_shape);

        out_slice = in_slice.reduce(dims_to_reduce, Reducer());
      }
      if (end >= num_indices) break;
      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
    }
  }
};

#ifdef GOOGLE_CUDA
// SegmentSumGPUOp is a segment sum operator implemented for GPU only.
// TODO: This implementation of SegmentSumGPUOp is sometimes slower than
// its unsorted counterpart (mostly when the problem size is small).
// This is due to the following two main reasons, and a cost-effective way
// to resolve these problems is desirable.
// 1. Sorted segment sum requires a memory transfer from device to host in
//    order to know the size of the output dimension, whereas unsorted segment
//    sum receives the size of the output dimension as an input parameter.
// 2. Sorted segment sum is essentially a tiled version of unsorted segment
//    sum, and therefore such optimization comes at an inherent cost. However,
//    such cost may not be justified when the problem size is small. When to
//    use the tiled version or the untiled version depends on many factors,
//    including data alignment, the ratio of computation to memory traffic
//    and, obviously, the problem sizes.
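// Implementation outline: the number of output rows equals the last (largest)
// segment id plus one, which lives in device memory. The kernel therefore
// issues an asynchronous device-to-host copy of that value and defers output
// allocation and the actual reduction to a callback scheduled on the GPU
// event manager once the copy has completed.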
template <class T, class Index>
class SegmentSumGPUOp : public AsyncOpKernel {
 public:
  explicit SegmentSumGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    const int64 num_indices = segment_ids.NumElements();
    OP_REQUIRES_ASYNC(
        context, num_indices == input.dim_size(0),
        errors::InvalidArgument(
            "segment_ids should be the same size as dimension 0 of"
            " input."),
        done);

    if (num_indices == 0) {
      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, 0);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);
      done();
      return;
    }

    se::DeviceMemoryBase output_rows_device(
        const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
        (num_indices - 1));
    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                         sizeof(Index))
            .ok(),
        errors::Internal(
            "SegmentSumGPUOp: failed to copy output_rows from device"),
        done);

    functor::SegmentSumFunctor<T, Index> functor_;
    auto create_and_check_output = [context, output_rows_host, &input,
                                    &segment_ids, &functor_, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Index output_rows = *output_rows_host.data();
      output_rows++;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      auto output_flat = output->flat_outer_dims<T>();
      auto data_ptr = input.template flat<T>().data();
      auto segment_flat = segment_ids.flat<Index>();
      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
               segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
               output_flat);

      done();
    };

    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }
};
#endif  // GOOGLE_CUDA

#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)

#define REGISTER_REAL_CPU_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
                              type, index_type, 0)

#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);

#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32); \
  REGISTER_REAL_CPU_KERNELS(type, int64)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32); \
  REGISTER_COMPLEX_CPU_KERNELS(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
  REGISTER_KERNEL_BUILDER(Name("SegmentSum") \
                              .Device(DEVICE_GPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          SegmentSumGPUOp<type, index_type>)

#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32); \
  REGISTER_GPU_SORTED_KERNELS(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA

// ____________________________________________________________________________
// Unsorted segment reduction ops.

namespace functor {

// The ReductionFunctor implementation for CPU.
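// The output is first filled with the reduction's initial value (e.g. 0 for
// sum, 1 for prod, lowest()/highest() for max/min). Each input row is then
// folded into the output row selected by its segment id; negative segment ids
// are silently skipped, and ids >= num_segments raise an InvalidArgument
// error.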
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const Index num_segments,
                  const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output) {
    output.setConstant(InitialValueF()());
    if (data_size == 0) {
      return;
    }
    const int64 N = segment_ids.dimension(0);
    ReductionF reduction;
    auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
    for (int64 i = 0; i < N; ++i) {
      Index j = internal::SubtleMustCopy(segment_ids(i));
      if (j < 0) {
        continue;
      }
      OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
                  errors::InvalidArgument(
                      "segment_ids", SliceDebugString(segment_ids_shape, i),
                      " = ", j, " is out of range [0, ", num_segments, ")"));
      reduction(data_flat.template chip<0>(i), output.template chip<0>(j));
    }
  }
};

template <typename T>
using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;

template <typename T>
using constMatrixChip =
    Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;

// reduction functors
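// Each functor accumulates one row of the input into one row of the output.
// The output row is expected to have been initialized with the matching
// initial-value functor (Zero, One, Lowest or Highest) beforehand.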
template <typename T>
struct SumOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output += data;
  }
};

template <typename T>
struct MaxOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMax(output);
  }
};

template <typename T>
struct MinOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMin(output);
  }
};

template <typename T>
struct ProdOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output *= data;
  }
};
}  // namespace functor

// Static check routines not in the templated class to reduce code size
static void UnsortedSegmentReductionValidation(OpKernel* op_kernel,
                                               OpKernelContext* context,
                                               const Tensor& data,
                                               const Tensor& segment_ids,
                                               const Tensor& num_segments) {
  OP_REQUIRES(
      context, op_kernel->IsLegacyScalar(num_segments.shape()),
      errors::InvalidArgument("num_segments should be a scalar, not shape ",
                              num_segments.shape().DebugString()));
  OP_REQUIRES(
      context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()),
      errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
                              " does not start with segment_ids.shape = ",
                              segment_ids.shape().DebugString()));
}

static bool UnsortedSegmentReductionDoValidation(OpKernel* op_kernel,
                                                 OpKernelContext* context,
                                                 const Tensor& data,
                                                 const Tensor& segment_ids,
                                                 const Tensor& num_segments) {
  UnsortedSegmentReductionValidation(op_kernel, context, data, segment_ids,
                                     num_segments);
  return context->status().ok();
}

// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
// is the device-specific implementation of the reduction. These
// device-specific implementations are themselves templated with the
// corresponding initial-value functors and reduction functors.
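// The output has shape [num_segments] + data.shape[segment_ids.dims():], i.e.
// the leading dimensions covered by segment_ids are collapsed into a single
// dimension of size num_segments.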
template <typename T, typename Index, typename DeviceReductionFunctor>
class UnsortedSegmentReductionOp : public OpKernel {
 public:
  explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& segment_ids = context->input(1);
    const Tensor& num_segments = context->input(2);
    if (!UnsortedSegmentReductionDoValidation(this, context, data, segment_ids,
                                              num_segments)) {
      return;
    }
    const auto segment_flat = segment_ids.flat<Index>();
    const Index output_rows =
        internal::SubtleMustCopy(num_segments.scalar<int32>()());
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("Input num_segments == ", output_rows,
                                        " must not be negative."));
    TensorShape output_shape;
    output_shape.AddDim(output_rows);
    for (int i = segment_ids.dims(); i < data.dims(); i++) {
      output_shape.AddDim(data.dim_size(i));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_flat = output->flat_outer_dims<T>();
    auto data_ptr = data.template flat<T>().data();
    reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat,
                       data.NumElements(), data_ptr, output_flat);
  }

 protected:
  DeviceReductionFunctor reduction_functor_;
};

#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT( \
    name, type, index_type, initial_value_functor, reduction_functor) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      UnsortedSegmentReductionOp< \
          type, index_type, \
          functor::UnsortedSegmentFunctor<CPUDevice, type, index_type, \
                                          initial_value_functor, \
                                          reduction_functor> >)

#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>, \
                                      functor::SumOp<type>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
                                      functor::Lowest<type>, \
                                      functor::MaxOp<type>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
                                      functor::Highest<type>, \
                                      functor::MinOp<type>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>, \
                                      functor::ProdOp<type>);

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>, \
                                      functor::SumOp<type>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>, \
                                      functor::ProdOp<type>)

#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);

#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
#undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT( \
    name, type, index_type, initial_value_functor, reduction_kernel_functor) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_GPU) \
          .HostMemory("num_segments") \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      UnsortedSegmentReductionOp< \
          type, index_type, \
          functor::UnsortedSegmentFunctor<GPUDevice, type, index_type, \
                                          initial_value_functor, \
                                          reduction_kernel_functor> >)

// Sum is currently the only op that supports all input types.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
                                      functor::Lowest<type>, \
                                      functor::MaxOpGpu<type>); \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
                                      functor::Highest<type>, \
                                      functor::MinOpGpu<type>); \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>, \
                                      functor::ProdOpGpu<type>);

#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>, \
                                      functor::SumOpGpu<type>);

#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64);

#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);

#undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL

#endif  // GOOGLE_CUDA

// ____________________________________________________________________________
// Sparse segment reduction ops.

// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
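// For example, SparseSegmentSum(data, indices, segment_ids) computes the same
// result as SegmentSum(gathered, segment_ids), where 'gathered' is the set of
// rows of 'data' selected by 'indices'. As with the dense ops, segment_ids
// must be sorted and non-negative.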
template <typename Device, class T>
class SparseSegmentReductionOpBase : public OpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : OpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    Index output_rows = -1;
    if (has_num_segments_) {
      const Tensor& num_segments = context->input(3);

      OP_REQUIRES(
          context, num_segments.shape().dims() == 0,
          errors::InvalidArgument("num_segments should be a scalar, not shape ",
                                  num_segments.shape().DebugString()));
      output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
      OP_REQUIRES(context, output_rows >= 0,
                  errors::InvalidArgument("segment ids must be >= 0"));
    }

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));

    const int64 num_indices = indices.NumElements();
    OP_REQUIRES(context, num_indices == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));

    auto input_flat = input.flat_outer_dims<T>();
    const int64 num_col = input_flat.dimension(1);
    const auto indices_vec = indices.vec<Index>();
    typedef int32 OutputRow;
    const auto segment_vec = segment_ids.vec<OutputRow>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const OutputRow last_segment_id_plus_one =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    if (has_num_segments_) {
      OP_REQUIRES(
          context, output_rows >= last_segment_id_plus_one,
          errors::InvalidArgument("segment ids must be < num_segments"));
    } else {
      output_rows = last_segment_id_plus_one;
    }
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value, so
    // we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) {
      if (output_rows > 0) {
        output->flat_outer_dims<T>().setConstant(default_value_);
      }
      return;
    }
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    int64 start = 0, end = 1;
    // Index from which the output is not initialized.
    OutputRow uninitialized_index = 0;
    OutputRow out_index = internal::SubtleMustCopy(segment_vec(start));

    while (true) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      OutputRow next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(default_value_);
      }

      auto out = output_flat.template chip<0>(out_index);
      const int bad_offset =
          Reduce(input_flat, indices_vec, start, end - start, out);
      OP_REQUIRES(context, bad_offset < 0,
                  errors::InvalidArgument(
                      "Bad: indices[", start + bad_offset,
                      "] == ", indices_vec(start + bad_offset),
                      " out of range [0, ", input_flat.dimension(0), ")"));

      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
      if (end > num_indices) break;
    }

    // Fill the gap at the end with the default value.
    if (uninitialized_index < output_rows) {
      Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
          output_rows - uninitialized_index, num_col);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
          gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
      gap_slice.setConstant(default_value_);
    }
  }

 private:
  typedef int32 Index;

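  // Reduces the rows of input_flat selected by indices_vec(start) ...
  // indices_vec(start + num - 1) into 'out'. The additions are manually
  // unrolled in groups of eight rows. For mean/sqrtn reductions the result is
  // divided by num (respectively sqrt(num)). Returns -1 on success, or the
  // offset of the first out-of-range index so the caller can report it.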
  int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
               const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
               int64 num,
               Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
#define INDEX(n, i) \
  const auto index##n = indices_vec(start + (i)); \
  if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);

#define L(n) input_flat.template chip<0>(index##n)

    if (num == 1) {
      INDEX(0, 0);
      out = L(0);
    } else {
      int64 r = num % 8;
      T m(1);
      if (is_mean_ && (num < 10)) {
        m = T(num);
      }
      if (is_sqrtn_ && (num < 10)) {
        m = T(sqrt(num));
      }
      switch (r) {
        case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
          out = (L(0) + L(1)) / m;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          out = (L(0) + L(1) + L(2)) / m;
          break;
        }
        case 4: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          out = (L(0) + L(1) + L(2) + L(3)) / m;
          break;
        }
        case 5: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
          break;
        }
        case 6: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
          break;
        }
        case 7: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
          break;
        }
        case 0: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
          r = 8;
          break;
        }
        case 1: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
                m;
          r = 9;
          break;
        }
      }
      for (; r < num; r += 8) {
        INDEX(0, r);
        INDEX(1, r + 1);
        INDEX(2, r + 2);
        INDEX(3, r + 3);
        INDEX(4, r + 4);
        INDEX(5, r + 5);
        INDEX(6, r + 6);
        INDEX(7, r + 7);
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
        out = out / static_cast<T>(num);
      }
      if (is_sqrtn_ && num >= 10) {
        out = out / static_cast<T>(sqrt(num));
      }
    }

    return -1;
#undef L
#undef INDEX
  }

  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

template <typename Device, class T>
class SparseSegmentReductionMeanOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionMeanWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionMeanWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSqrtNOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSumOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSumWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSumWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentReductionSumOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSumWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<int32>("Tidx"), \
      SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentReductionMeanOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentMeanWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<int32>("Tidx"), \
      SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSqrtNWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<int32>("Tidx"), \
      SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

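// Implements the gradients of SparseSegmentMean and SparseSegmentSqrtN: each
// row of the incoming gradient is scattered back to the positions given by
// 'indices', scaled by 1/k (mean) or 1/sqrt(k) (sqrtn), where k is the number
// of entries in the corresponding segment.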
template <class T>
class SparseSegmentGradOpBase : public OpKernel {
 public:
  explicit SparseSegmentGradOpBase(OpKernelConstruction* context, bool is_sqrtn)
      : OpKernel(context), is_sqrtn_(is_sqrtn) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);
    const Tensor& output_dim0 = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));
    OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
                errors::InvalidArgument("output_dim0 should be a scalar."));

    const int64 N = indices.NumElements();
    OP_REQUIRES(context, N == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));
    typedef int32 SegmentId;
    const SegmentId M =
        internal::SubtleMustCopy(output_dim0.scalar<SegmentId>()());

    auto input_flat = input.flat_outer_dims<T>();
    typedef int32 Index;
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, M);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (M == 0 || N == 0) return;

    // Note that similar to SparseSegmentMean, we assume that segment_vec is
    // already sorted and has non-negative values.
    const SegmentId num_segments = input.dim_size(0);
    const SegmentId last_segment_id_plus_one =
        internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
    OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
                errors::InvalidArgument("Invalid number of segments"));

    // Compute scaling factors for input.
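    // scaling[s] first counts the occurrences of segment s and is then
    // inverted (1/k for mean, 1/sqrt(k) for sqrtn). Empty segments keep a
    // factor of 1.0 via std::max to avoid dividing by zero.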
    std::vector<double> scaling(num_segments, 0.0);
    for (int64 i = 0; i < N; ++i) {
      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));
      scaling[idx] += 1;
    }
    for (size_t i = 0; i < scaling.size(); ++i) {
      if (is_sqrtn_) {
        scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
      } else {
        scaling[i] = 1.0 / std::max(scaling[i], 1.0);
      }
    }

    auto output_flat = output->flat_outer_dims<T>();
    output_flat.setZero();
    std::vector<bool> is_modified(M, false);

    for (int64 i = 0; i < N; ++i) {
      const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
      OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
                  errors::InvalidArgument("Index ", output_idx,
                                          " out of range [0, ", M, ")."));

      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));

      const T scale = static_cast<T>(scaling[idx]);
      if (is_modified[output_idx]) {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx) * scale;
        }
      } else {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx) * scale;
        }
      }
      is_modified[output_idx] = true;
    }
  }

 private:
  const bool is_sqrtn_;
};

template <class T>
class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
 public:
  explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
};

template <class T>
class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
 public:
  explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
};

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentMeanGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentSqrtNGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS
}  // namespace tensorflow