/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

namespace detail {

template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const T size,
                                T* out) {
  GPU_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
}

// Initialize out with range start, start + delta, start + 2 * delta, ...
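// For example, RangeInit(d, T(0), T(1), T(size), out) fills out with
// 0, 1, ..., size - 1.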
template <typename T>
Status RangeInit(const Eigen::GpuDevice& d, const T start, const T delta,
                 const T size, T* out) {
  if (size == 0) return Status::OK();
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  return GpuLaunchKernel(RangeInitKernel<T>, config.block_count,
                         config.thread_per_block, 0, d.stream(), start, delta,
                         size, out);
}

}  // namespace detail

// Computes keys_out = sorted(keys_in), and indices_out = argsort(keys_in).
// If keys_out is not required, it can be set to nullptr.
// If indices_in is nullptr, the range of input indices [0, size) will be used.
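//
// Example (a sketch; assumes `ctx` is the current OpKernelContext and `d_keys`
// and `d_order` are device pointers with `size` elements):
//   TF_RETURN_IF_ERROR((GpuRadixSort<float, int32>(
//       ctx, size, d_keys, /*keys_out=*/nullptr, /*indices_in=*/nullptr,
//       d_order)));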
template <typename Tkey, typename Tindex>
Status GpuRadixSort(OpKernelContext* context, int size, const Tkey* keys_in,
                    Tkey* keys_out,            // Optional
                    const Tindex* indices_in,  // Optional
                    Tindex* indices_out, int num_bits = sizeof(Tkey) * 8) {
  if (size == 0) return Status::OK();
  // Allocate temporary inputs/outputs if necessary.
  Tensor tmp_indices_in;
  if (!indices_in) {
    TF_RETURN_IF_ERROR(context->allocate_temp(
        DataTypeToEnum<Tindex>::value, TensorShape({size}), &tmp_indices_in));
    Tindex* mutable_indices_in = tmp_indices_in.flat<Tindex>().data();
    indices_in = mutable_indices_in;
    const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
    // Initialize indices_in to the input index range.
    TF_RETURN_IF_ERROR(detail::RangeInit(device, Tindex(0), Tindex(1),
                                         Tindex(size), mutable_indices_in));
  }
  Tensor tmp_keys_out;
  if (!keys_out) {
    TF_RETURN_IF_ERROR(context->allocate_temp(
        DataTypeToEnum<Tkey>::value, TensorShape({size}), &tmp_keys_out));
    keys_out = tmp_keys_out.flat<Tkey>().data();
  }
  // Determine temporary device storage requirements.
  Tensor temp_storage;
  size_t temp_storage_bytes = 0;
  const auto& cu_stream = GetGpuStream(context);
  auto err = gpuprim::DeviceRadixSort::SortPairs(
      nullptr, temp_storage_bytes, keys_in, keys_out, indices_in, indices_out,
      size, /*begin_bit=*/0, /*end_bit=*/num_bits, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceRadixSort::SortPairs to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  // Allocate temporary storage.
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  // Sort indices by keys.
  err = gpuprim::DeviceRadixSort::SortPairs(
      temp_storage.flat<int8>().data(), temp_storage_bytes, keys_in, keys_out,
      indices_in, indices_out, size, /*begin_bit=*/0, /*end_bit=*/num_bits,
      cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceRadixSort::SortPairs, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return Status::OK();
}
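
// Computes an inclusive prefix sum: output[i] = input[0] + ... + input[i].
//
// Example (a sketch; assumes `ctx` is the current OpKernelContext and `d_in`
// and `d_out` are device pointers to `size` int32 elements):
//   TF_RETURN_IF_ERROR(GpuInclusivePrefixSum(ctx, size, d_in, d_out));
//
// For bool inputs, cast to an integer type first, e.g. with a hypothetical
// device functor `CastToInt32` returning int32:
//   gpuprim::TransformInputIterator<int32, CastToInt32, const bool*> it(
//       d_bool_in, CastToInt32());
//   TF_RETURN_IF_ERROR(GpuInclusivePrefixSum(ctx, size, it, d_out));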
template <typename InputIteratorT, typename OutputIteratorT>
Status GpuInclusivePrefixSum(OpKernelContext* context, int size,
                             InputIteratorT input, OutputIteratorT output) {
  static_assert(
      !std::is_same<typename std::remove_reference<decltype(*input)>::type,
                    bool>::value,
      "GpuInclusivePrefixSum does not work correctly with booleans; please "
      "use TransformInputIterator to explicitly cast to an integer.");
  if (size == 0) return Status::OK();
  const auto& cu_stream = GetGpuStream(context);
  size_t temp_storage_bytes;
  auto err = gpuprim::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes,
                                               input, output, size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceScan::InclusiveSum to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceScan::InclusiveSum(temp_storage.flat<int8>().data(),
                                          temp_storage_bytes, input, output,
                                          size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceScan::InclusiveSum, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return Status::OK();
}

// Reduces each segment of `input` (segments are delimited by consecutive
// entries of `segment_offsets`) with `reduce_op` and writes one result per
// segment to `output`. Note that this behaves deterministically for repeated
// calls on the same device.
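//
// Example (a sketch; assumes `ctx` is the current OpKernelContext, `d_vals`
// points to the values on device, `d_offsets` holds num_segments + 1 segment
// offsets, `d_sums` has room for num_segments floats, and gpuprim::Sum is the
// cub/hipcub sum functor exposed via gpu_prim.h):
//   TF_RETURN_IF_ERROR(GpuSegmentedReduce(ctx, num_segments, gpuprim::Sum(),
//                                         0.0f, d_vals, d_offsets, d_sums));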
template <typename InputIteratorT, typename OutputIteratorT,
          typename OffsetIteratorT, typename ReduceOp, typename T>
Status GpuSegmentedReduce(
    OpKernelContext* context, int num_segments, ReduceOp reduce_op,
    const T& initial_value,
    InputIteratorT input,             // [any]
    OffsetIteratorT segment_offsets,  // [num_segments + 1]
    OutputIteratorT output) {         // [num_segments]
  if (num_segments == 0) return Status::OK();
  const auto& cu_stream = GetGpuStream(context);
  size_t temp_storage_bytes;
  auto err = gpuprim::DeviceSegmentedReduce::Reduce(
      nullptr, temp_storage_bytes, input, output, num_segments,
      segment_offsets, segment_offsets + 1, reduce_op, initial_value,
      cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSegmentedReduce::Reduce to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceSegmentedReduce::Reduce(
      temp_storage.flat<int8>().data(), temp_storage_bytes, input, output,
      num_segments, segment_offsets, segment_offsets + 1, reduce_op,
      initial_value, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSegmentedReduce::Reduce"
        ", temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return Status::OK();
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_