/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#include "tensorflow/core/util/cuda_solvers.h"
using stream_executor::rocm::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace internal {
Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input,
                                const Tensor& segment_ids);
Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
                                        OpKernelContext* context,
                                        const Tensor& data,
                                        const Tensor& segment_ids,
                                        const Tensor& num_segments);
Status ValidateSparseSegmentReduction(OpKernelContext* context,
                                      const Tensor& input,
                                      const Tensor& indices,
                                      const Tensor& segment_ids,
                                      bool has_num_segments);
}  // namespace internal

// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
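//
// Illustrative sketch of the sorted-segment semantics implemented below
// (comment added for readability; the values are hypothetical, and
// tf.math.segment_sum is simply the public entry point for the sum reducer):
//
//   data        = [[1, 2], [3, 4], [5, 6]]
//   segment_ids = [0, 0, 2]    // must be sorted and non-negative
//   segment_sum(data, segment_ids)
//     == [[4, 6],              // rows 0 and 1 reduced into segment 0
//         [0, 0],              // skipped id 1 is filled with default_value
//         [5, 6]]              // row 2 alone forms segment 2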
template <typename Device, class T, class Index, typename Reducer,
          int default_value>
class SegmentReductionOp : public OpKernel {
 public:
  explicit SegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_OK(context, internal::ValidateSegmentReduction(context, input,
                                                               segment_ids));

    const int64_t num_indices = segment_ids.NumElements();
    auto input_flat = input.flat_outer_dims<T>();
    const int64_t num_col = input_flat.dimension(1);

    const auto segment_vec = segment_ids.vec<Index>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const Index output_rows =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    OP_REQUIRES(context, input.dims() >= 1,
                errors::InvalidArgument("Shape must be at least rank 1"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
    dims_to_reduce[0] = 0;
#else
    Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
#endif
    Index start = 0, end = 1;

    Index uninitialized_index = 0;  // Index from which the output is not set.
    Index out_index = internal::SubtleMustCopy(segment_vec(start));

    // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    // across threads.
    Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    while (end <= num_indices) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      Index next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      // Process segment [start, end)
      const T* in_slice_ptr = &input_flat(start, 0);
      typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                               Eigen::Unaligned>
          OutT;

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(T(default_value));
      }

      T* out_slice_ptr = &output_flat(out_index, 0);
      OutT out_slice(out_slice_ptr, out_slice_shape);
      // We don't use out_slice.device(context->eigen_device<Device>)
      // because these pieces of work are likely to be very small and
      // the context switching overhead dwarfs any benefit we get from
      // using another thread to do this work.
      if (start == end - 1) {
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, out_slice_shape);
        out_slice = in_slice;
      } else {
        Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
                                                           num_col);
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, in_slice_shape);

        out_slice = in_slice.reduce(dims_to_reduce, Reducer());
      }
      if (end >= num_indices) break;
      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// SegmentReductionGPUOp is a segment reduction operator implemented for GPU
// only.
// TODO: This implementation of SegmentReductionGPUOp is sometimes slower than
// its unsorted counterpart (mostly when the problem size is small).
// This is due to the following two main reasons, and a cost-effective way
// to resolve these problems is desirable.
// 1. Sorted segment reduction requires a memory transfer from device to host
//    in order to know the size of the output dimension, whereas unsorted
//    segment reduction receives the size of the output dimension as an input
//    parameter.
// 2. Sorted segment reduction is essentially a tiled version of unsorted
//    segment reduction, and therefore such optimization comes at an inherent
//    cost. However, such cost may not be justified when the problem size is
//    small. When to use the tiled version or the untiled version depends on
//    many factors, including data alignment, the ratio of calculation to
//    memory traffic, and obviously, the problem sizes.
template <class T, class Index, class SegmentReductionFunctor>
class SegmentReductionGPUOp : public AsyncOpKernel {
 public:
  explicit SegmentReductionGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    OP_REQUIRES_ASYNC(context, input.dims() >= 1,
                      errors::InvalidArgument("Shape must be at least rank 1"),
                      done);

    const int64_t num_indices = segment_ids.NumElements();
    OP_REQUIRES_ASYNC(
        context, num_indices == input.dim_size(0),
        errors::InvalidArgument(
            "segment_ids should be the same size as dimension 0 of"
            " input."),
        done);

    if (num_indices == 0) {
      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, 0);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);
      done();
      return;
    }

    se::DeviceMemoryBase output_rows_device(
        const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
        (num_indices - 1));
    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                         sizeof(Index))
            .ok(),
        errors::Internal(type_string() +
                         ": failed to copy output_rows from device"),
        done);

    SegmentReductionFunctor functor_;
    auto create_and_check_output = [context, output_rows_host, &input,
                                    &segment_ids, &functor_, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Index output_rows = *output_rows_host.data();
      output_rows++;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      // The determinism check is here, rather than inside the functor (as it
      // is for the unsorted segment reduction ops) because the done callback
      // (required for OP_REQUIRES_ASYNC) is not available inside the functor.
      bool determinism_requirement_met =
          SegmentReductionFunctor::atomic_reduction_is_associative ||
          !OpDeterminismRequired() ||
          DisableSegmentReductionOpDeterminismExceptions();
      OP_REQUIRES_ASYNC(
          context, determinism_requirement_met,
          errors::Unimplemented(
              "Deterministic GPU implementation of sorted segment reduction op"
              " not available."),
          done);

      auto output_flat = output->flat_outer_dims<T>();
      auto data_ptr = input.template flat<T>().data();
      auto segment_flat = segment_ids.flat<Index>();
      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
               segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
               output_flat);

      done();
    };

    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// ____________________________________________________________________________
// Unsorted segment reduction ops.

namespace functor {

// The ReductionFunctor implementation for CPU.
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  typename TTypes<T, 2>::ConstTensor data,
                  typename TTypes<T, 2>::Tensor output) {
    output.setConstant(InitialValueF()());
    if (data.size() == 0) {
      return;
    }
    const int64_t N = segment_ids.dimension(0);
    const int64_t num_segments = output.dimension(0);
    ReductionF reduction;
    for (int64_t i = 0; i < N; ++i) {
      Index j = internal::SubtleMustCopy(segment_ids(i));
      if (j < 0) {
        continue;
      }
      OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
                  errors::InvalidArgument(
                      "segment_ids", SliceDebugString(segment_ids_shape, i),
                      " = ", j, " is out of range [0, ", num_segments, ")"));
      reduction(data.template chip<0>(i), output.template chip<0>(j));
    }
  }
};

template <typename T>
using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;

template <typename T>
using constMatrixChip =
    Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;

// reduction functors
template <typename T>
struct SumOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output += data;
  }
};

template <typename T>
struct MaxOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMax(output);
  }
};

template <typename T>
struct MinOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMin(output);
  }
};

template <typename T>
struct ProdOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output *= data;
  }
};
}  // namespace functor

// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
// is the device specific implementation of the reduction. These device
// specific implementations are templated themselves with the corresponding
// initial value functors and reduction functors.
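//
// Illustrative sketch (comment added for readability, with hypothetical
// values): for the sum reduction with a zero initial value, e.g.
// tf.math.unsorted_segment_sum,
//
//   data         = [[1, 2], [3, 4], [5, 6]]
//   segment_ids  = [2, 0, 2]    // need not be sorted
//   num_segments = 4
//   output       = [[3, 4], [0, 0], [6, 8], [0, 0]]
//
// Negative segment ids are silently skipped by the CPU functor above, while
// ids >= num_segments trigger an InvalidArgument error.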
template <typename T, typename Index, typename DeviceReductionFunctor>
class UnsortedSegmentReductionOp : public OpKernel {
 public:
  explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& segment_ids = context->input(1);
    const Tensor& num_segments = context->input(2);
    OP_REQUIRES_OK(context,
                   internal::ValidateUnsortedSegmentReduction(
                       this, context, data, segment_ids, num_segments));
    const auto segment_flat = segment_ids.flat<Index>();
    const int64_t output_rows = internal::SubtleMustCopy(static_cast<int64>(
        num_segments.dtype() == DT_INT32 ? num_segments.scalar<int32>()()
                                         : num_segments.scalar<int64>()()));
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("Input num_segments == ", output_rows,
                                        " must not be negative."));
    TensorShape output_shape;
    output_shape.AddDim(output_rows);
    for (int i = segment_ids.dims(); i < data.dims(); i++) {
      output_shape.AddDim(data.dim_size(i));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_flat = output->flat_outer_dims<T>();
    auto data_flat = data.flat_inner_outer_dims<T, 2>(segment_ids.dims() - 1);
    reduction_functor_(context, segment_ids.shape(), segment_flat, data_flat,
                       output_flat);
  }

 protected:
  DeviceReductionFunctor reduction_functor_;
};

// ____________________________________________________________________________
// Sparse segment reduction ops.

// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase : public OpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : OpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_OK(
        context, internal::ValidateSparseSegmentReduction(
                     context, input, indices, segment_ids, has_num_segments_));

    Index output_rows = -1;
    if (has_num_segments_) {
      const Tensor& num_segments = context->input(3);
      // Note that there is a Tnumsegments parameter on the op, but it is not
      // plumbed through to here and so always takes its default value of
      // int32.
      output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
    }
    const int64_t num_indices = indices.NumElements();

    auto input_flat = input.flat_outer_dims<T>();
    const int64_t num_col = input_flat.dimension(1);
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const SegmentId last_segment_id_plus_one =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    if (has_num_segments_) {
      OP_REQUIRES(
          context, output_rows >= last_segment_id_plus_one,
          errors::InvalidArgument("segment ids must be < num_segments"));
    } else {
      output_rows = last_segment_id_plus_one;
    }
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) {
      if (output_rows > 0) {
        output->flat_outer_dims<T>().setConstant(default_value_);
      }
      return;
    }
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    Tensor temp;
    if (input.dtype() == DT_BFLOAT16 || input.dtype() == DT_HALF) {
      temp = tensorflow::Tensor(DT_FLOAT, output_shape);
    }
    auto temp_flat = temp.flat_outer_dims<float>();

    int64_t start = 0, end = 1;
    // Index from which the output is not initialized.
    SegmentId uninitialized_index = 0;
    SegmentId out_index = internal::SubtleMustCopy(segment_vec(start));

    while (true) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      SegmentId next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(default_value_);
      }

      auto out = output_flat.template chip<0>(out_index);
      auto temp = temp_flat.template chip<0>(out_index);
      const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
                                              end - start, out, temp);
      OP_REQUIRES(context, bad_offset < 0,
                  errors::InvalidArgument(
                      "Bad: indices[", start + bad_offset,
                      "] == ", indices_vec(start + bad_offset),
                      " out of range [0, ", input_flat.dimension(0), ")"));

      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
      if (end > num_indices) break;
    }

    // Fill the gap at the end with the default value.
    if (uninitialized_index < output_rows) {
      Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
          output_rows - uninitialized_index, num_col);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
          gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
      gap_slice.setConstant(default_value_);
    }
  }

 private:
  template <typename Tin>
  using EnableIfBfloat16OrHalf =
      typename std::enable_if<std::is_same<Tin, bfloat16>::value ||
                                  std::is_same<Tin, Eigen::half>::value,
                              int>::type;
  template <typename Tin>
  using EnableIfNotBfloat16OrHalf =
      typename std::enable_if<!std::is_same<Tin, bfloat16>::value &&
                                  !std::is_same<Tin, Eigen::half>::value,
                              int>::type;

  template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
  EIGEN_ALWAYS_INLINE auto fetch_val(
      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
    return input_flat.template chip<0>(index);
  }

  template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
  EIGEN_ALWAYS_INLINE auto fetch_val(
      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
    return input_flat.template chip<0>(index).template cast<float>();
  }

  template <typename Tout>
  EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64_t num) {
    Tout m(1);
    if (is_mean_ && (num < 10)) {
      m = Tout(num);
    }
    if (is_sqrtn_ && (num < 10)) {
      m = Tout(sqrt(num));
    }
    return Tout(1) / m;
  }

  template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
  int64 Reduce(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
    return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
                                        out, get_scaling_factor<Tin>(num));
  }

  template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
  int64 Reduce(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
    int64_t res =
        ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
                                       temp, get_scaling_factor<float>(num));
    out = temp.template cast<Tin>();
    return res;
  }

  template <typename Tin, typename Tindex, typename Tout>
  int64 ReduceImpl(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
      const Tout scaling_factor) {
#define INDEX(n, i)                               \
  const auto index##n = indices_vec(start + (i)); \
  if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);

#define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)

    if (num == 1) {
      INDEX(0, 0);
      out = L(0);
    } else {
      int64_t r = num & 7;
      switch (r) {
        case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
          out = (L(0) + L(1)) * scaling_factor;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          out = (L(0) + L(1) + L(2)) * scaling_factor;
          break;
        }
        case 4: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
          break;
        }
        case 5: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
          break;
        }
        case 6: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
          break;
        }
        case 7: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          out =
              (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor;
          break;
        }
        case 0: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
                scaling_factor;
          r = 8;
          break;
        }
        case 1: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) *
                scaling_factor;
          r = 9;
          break;
        }
      }
      for (; r < num; r += 8) {
        INDEX(0, r);
        INDEX(1, r + 1);
        INDEX(2, r + 2);
        INDEX(3, r + 3);
        INDEX(4, r + 4);
        INDEX(5, r + 5);
        INDEX(6, r + 6);
        INDEX(7, r + 7);
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
        out = out / static_cast<Tout>(num);
      }
      if (is_sqrtn_ && num >= 10) {
        out = out / static_cast<Tout>(sqrt(num));
      }
    }

    return -1;
#undef L
#undef INDEX
  }

  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Specialization for GPU. Must be Async because it may need to wait for a
// device-to-host memcpy before allocating the output.
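//
// Control flow summary (comment added for readability): when `num_segments`
// is supplied as an input, the output shape is already known on the host and
// the kernel body runs immediately; otherwise the last segment id is copied
// from the device into host scratch space and the rest of the computation is
// deferred to the stream's EventMgr callback.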
template <class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase<GPUDevice, T, Index, SegmentId>
    : public AsyncOpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : AsyncOpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_OK_ASYNC(
        context,
        internal::ValidateSparseSegmentReduction(
            context, input, indices, segment_ids, has_num_segments_),
        done);

    ScratchSpace<SegmentId> last_segment_id_host(context, 1, /*on_host=*/true);

    auto create_and_check_output = [this, context, input, indices, segment_ids,
                                    last_segment_id_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      SegmentId last_segment_id = *last_segment_id_host.data();
      SegmentId output_rows = last_segment_id + 1;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      auto input_flat = input.flat_outer_dims<T>();
      const auto indices_vec = indices.vec<Index>();
      const auto segment_ids_vec = segment_ids.vec<SegmentId>();
      auto output_flat = output->flat_outer_dims<T>();

      functor::SparseSegmentReductionFunctor<T, Index, SegmentId> functor;
      OP_REQUIRES_OK_ASYNC(
          context,
          functor(context, is_mean_, is_sqrtn_, default_value_, input_flat,
                  indices_vec, segment_ids_vec, output_flat),
          done);
      done();
    };

    if (has_num_segments_) {
      // No need to do any device to host memcpy, just compute synchronously.
      const Tensor& num_segments_t = context->input(3);
      SegmentId num_segments =
          internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32
                                       ? num_segments_t.scalar<int32>()()
                                       : num_segments_t.scalar<int64>()());
      *last_segment_id_host.mutable_data() = num_segments - 1;
      create_and_check_output();
    } else {
      const int64_t num_indices = indices.NumElements();
      // Need to copy last element of segment_ids from device to host, and then
      // asynchronously allocate the output and finish the computation.
      se::DeviceMemoryBase last_segment_id_device(
          const_cast<Tensor&>(segment_ids).template flat<SegmentId>().data() +
          (num_indices - 1));
      auto stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(
          context,
          stream
              ->ThenMemcpy(last_segment_id_host.mutable_data(),
                           last_segment_id_device, sizeof(SegmentId))
              .ok(),
          errors::Internal(type_string() +
                           ": failed to copy last_segment_id from device"),
          done);
      context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
          stream, create_and_check_output);
    }
  }

 private:
  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionMeanWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSumWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

namespace functor {

template <typename T, typename Index, typename SegmentId>
struct SparseSegmentGradFunctor<CPUDevice, T, Index, SegmentId> {
  void operator()(OpKernelContext* context,
                  SparseSegmentReductionOperation operation,
                  typename TTypes<T>::ConstMatrix input_flat,
                  typename TTypes<Index>::ConstVec indices_vec,
                  typename TTypes<SegmentId>::ConstVec segment_vec,
                  typename TTypes<T>::Matrix output_flat) {
    const int64_t N = indices_vec.size();
    const SegmentId M = output_flat.dimension(0);

    // Note that similar to SparseSegmentMean, we assume that segment_vec is
    // already sorted and has non-negative values.
    const SegmentId num_segments = input_flat.dimension(0);
    const SegmentId last_segment_id_plus_one =
        internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
    OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
                errors::InvalidArgument("Invalid number of segments"));

    // Compute scaling factors for input.
    std::vector<double> scaling(
        (operation == SparseSegmentReductionOperation::kSum ? 0
                                                            : num_segments),
        0.0);
    if (operation != SparseSegmentReductionOperation::kSum) {
      for (int64_t i = 0; i < N; ++i) {
        const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
        OP_REQUIRES(
            context, FastBoundsCheck(idx, num_segments),
            errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                    num_segments, ")."));
        scaling[idx] += 1;
      }
      for (size_t i = 0; i < scaling.size(); ++i) {
        switch (operation) {
          case SparseSegmentReductionOperation::kSum: {
            OP_REQUIRES(
                context, false,
                errors::Internal(
                    "Should not happen: sum inside SparseSegmentReductionOp "
                    "scaling generation."));
          }
          case SparseSegmentReductionOperation::kMean: {
            scaling[i] = 1.0 / std::max(scaling[i], 1.0);
            break;
          }
          case SparseSegmentReductionOperation::kSqrtN: {
            scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
            break;
          }
            // No default to get compiler warnings for missing cases.
        }
      }
    }

    output_flat.setZero();
    std::vector<bool> is_modified(M, false);

    for (int64_t i = 0; i < N; ++i) {
      const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
      OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
                  errors::InvalidArgument("Index ", output_idx,
                                          " out of range [0, ", M, ")."));

      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));

      const T scale = (operation == SparseSegmentReductionOperation::kSum
                           ? static_cast<T>(1)
                           : static_cast<T>(scaling[idx]));
      if (is_modified[output_idx]) {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx) * scale;
        }
      } else {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx) * scale;
        }
      }
      is_modified[output_idx] = true;
    }
  }
};

}  // namespace functor

// Implements the common logic for the gradients of SparseSegmentReduction
// kernels.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentGradOpBase : public OpKernel {
 public:
  explicit SparseSegmentGradOpBase(OpKernelConstruction* context,
                                   SparseSegmentReductionOperation operation)
      : OpKernel(context), operation_(operation) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);
    const Tensor& output_dim0 = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()),
                errors::InvalidArgument("output_dim0 should be a scalar."));

    const int64_t N = indices.NumElements();
    OP_REQUIRES(context, N == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));
    const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar<int32>()());

    auto input_flat = input.flat_outer_dims<T>();
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, M);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (M == 0 || N == 0) return;

    auto output_flat = output->flat_outer_dims<T>();
    functor::SparseSegmentGradFunctor<Device, T, Index, SegmentId>()(
        context, operation_, input_flat, indices_vec, segment_vec, output_flat);
  }

 private:
  const SparseSegmentReductionOperation operation_;
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentSumGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentSumGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kSum) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentMeanGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kMean) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentSqrtNGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kSqrtN) {}
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_