/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/where_op.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

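// Expands the flat (row-major) index stored in column 0 of each output row
// into NDIM per-dimension coordinates, using the precomputed row-major
// strides.  For example, with a 2x3 input (strides {3, 1}), a flat index of
// 4 expands to the coordinates (1, 1): 4 / 3 = 1, then 4 % 3 = 1, 1 / 1 = 1.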
template <int NDIM, typename TIndex>
__global__ void PropagateWhereIndicesKernel(
    const TIndex output_rows, const typename Eigen::array<TIndex, NDIM> strides,
    int64* __restrict__ output) {
  // TODO(ebrevdo): Use a multi-dimensional loop, increasing the
  // dimensions of individual indices manually, instead of relying on
  // a scalar loop variable and using integer division.
  GPU_1D_KERNEL_LOOP(i, output_rows) {
    TIndex index_value = ldg(output + NDIM * i);
#pragma unroll
    for (int c = 0; c < NDIM; ++c) {
      *(output + NDIM * i + c) = index_value / strides[c];
      index_value %= strides[c];
    }
  }
}

namespace {

template <typename T>
struct IsNonzero {
  EIGEN_DEVICE_FUNC IsNonzero() : zero(T(0)) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x) const {
    return (x != zero);
  }
  const T zero;
};

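// Counts the nonzero (true) elements of d_in by summing the output of an
// IsNonzero transform iterator with gpuprim::DeviceReduce::Sum.  The bool
// specialization below skips the transform and sums the flags directly.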
template <typename T, typename TIndex>
struct CubDeviceReduceCount {
  gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                        const T* d_in, TIndex* d_out, int num_items,
                        gpuStream_t stream = 0,
                        bool debug_synchronous = false) {
    IsNonzero<T> is_nonzero;
    gpuprim::TransformInputIterator<bool, IsNonzero<T>, const T*>
        is_nonzero_iter(d_in, is_nonzero);
    return gpuprim::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
                                      is_nonzero_iter, d_out, num_items, stream,
                                      debug_synchronous);
  }
};

template <typename TIndex>
struct CubDeviceReduceCount<bool, TIndex> {
  gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                        const bool* d_in, TIndex* d_out, int num_items,
                        gpuStream_t stream = 0,
                        bool debug_synchronous = false) {
    return gpuprim::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in,
                                      d_out, num_items, stream,
                                      debug_synchronous);
  }
};

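// Wraps gpuprim::DeviceSelect::Flagged to write the flat indices of the
// nonzero elements of d_flags to d_out.  A CountingInputIterator supplies the
// candidate values (0, 1, 2, ...), so the "selected" values are exactly the
// flat positions whose flag is nonzero.  The IsConvertibleToBool=true
// specialization passes d_flags straight through as the flag array; the
// false specialization first maps each element to a bool via IsNonzero.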
template <typename T, typename TIndex, typename OutputIterator,
          bool IsConvertibleToBool>
struct CubDeviceSelectFlaggedCounter;

template <typename T, typename TIndex, typename OutputIterator>
struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
                                     false /*IsConvertibleToBool*/> {
  gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                        const T* d_flags, OutputIterator d_out,
                        TIndex* d_num_selected_out, int num_items,
                        gpuStream_t stream = 0,
                        bool debug_synchronous = false) {
    gpuprim::CountingInputIterator<TIndex> select_counter(0);
    IsNonzero<T> is_nonzero;
    gpuprim::TransformInputIterator<bool, IsNonzero<T>, const T*>
        is_nonzero_iter(d_flags, is_nonzero);
    return gpuprim::DeviceSelect::Flagged(
        d_temp_storage, temp_storage_bytes, select_counter /*d_in*/,
        is_nonzero_iter /*d_flags*/, d_out, d_num_selected_out, num_items,
        stream, debug_synchronous);
  }
};

template <typename T, typename TIndex, typename OutputIterator>
struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
                                     true /*IsConvertibleToBool*/> {
  gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                        const T* d_flags, OutputIterator d_out,
                        TIndex* d_num_selected_out, int num_items,
                        gpuStream_t stream = 0,
                        bool debug_synchronous = false) {
    gpuprim::CountingInputIterator<TIndex> select_counter(0);
    return gpuprim::DeviceSelect::Flagged(
        d_temp_storage, temp_storage_bytes, select_counter /*d_in*/, d_flags,
        d_out, d_num_selected_out, num_items, stream, debug_synchronous);
  }
};

}  // namespace

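// NumTrue on GPU follows the usual two-pass gpuprim/CUB pattern: the first
// call (with a null temp_storage pointer) only computes temp_storage_bytes,
// a temporary buffer of that size is then allocated via the OpKernelContext,
// and the second call performs the actual reduction into num_true on the
// provided stream.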
template <typename T, typename TIndex>
struct NumTrue<GPUDevice, T, TIndex> {
  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const GPUDevice& d,
      typename TTypes<T>::ConstFlat input,
      typename TTypes<TIndex>::UnalignedScalar num_true) {
    const auto& cu_stream = GetGpuStream(ctx);

    std::size_t temp_storage_bytes = 0;
    const T* input_data = input.data();
    TIndex* num_true_data = num_true.data();

    // TODO(ebrevdo): sum doesn't work; perhaps need a different
    // iterator?
    auto reducer = CubDeviceReduceCount<T, TIndex>();
    auto first_success = reducer(/*temp_storage*/ nullptr, temp_storage_bytes,
                                 /*d_in*/ input_data,
                                 /*d_out*/ num_true_data,
                                 /*num_items*/ input.size(),
                                 /*stream*/ cu_stream);

    if (first_success != gpuSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch gpuprim::DeviceReduce::Sum to calculate "
          "temp_storage_bytes, status: ",
          GpuGetErrorString(first_success));
    }

    Tensor temp_storage;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
        &temp_storage));

    auto second_success = reducer(
        /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes,
        /*d_in*/ input_data,
        /*d_out*/ num_true_data,
        /*num_items*/ input.size(),
        /*stream*/ cu_stream);

    if (second_success != gpuSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch gpuprim::DeviceReduce::Sum to count "
          "number of true / nonzero indices. temp_storage_bytes: ",
          temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
    }

    return Status::OK();
  }
};

#define NUMTRUE_GPU_FUNCTOR(T)                  \
  template struct NumTrue<GPUDevice, T, int32>; \
  template struct NumTrue<GPUDevice, T, int64>;

// We only need to declare the NumTrue functor once, but this file is
// included from where_op_gpu_impl_X.cu.cc for X=1,2,...
// Only declare for X = 1.
#if GPU_PROVIDED_DIM == 1

TF_CALL_WHERE_GPU_TYPES(NUMTRUE_GPU_FUNCTOR);

#endif  // GPU_PROVIDED_DIM == 1

#undef NUMTRUE_GPU_FUNCTOR

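// Output "iterator" handed to DeviceSelect::Flagged.  Writing to element n
// stores the selected flat index into column 0 of output row n (i.e. at
// offset NDIM * n); the remaining NDIM - 1 columns of each row are filled in
// afterwards by PropagateWhereIndicesKernel.  Out-of-range writes are
// redirected to offset 0, and the count reported in d_num_selected_out is
// compared against the number of output rows afterwards to detect this case.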
template <int NDIM>
class WhereOutputIterator {
 public:
  // Required iterator traits
  typedef WhereOutputIterator self_type;
  typedef std::ptrdiff_t difference_type;
  typedef void value_type;
  typedef void pointer;
  typedef int64& reference;

#if (THRUST_VERSION >= 100700)
  // Use Thrust's iterator categories so we can use these iterators in Thrust
  // 1.7 (or newer) methods
  typedef typename thrust::detail::iterator_facade_category<
      thrust::device_system_tag, thrust::random_access_traversal_tag,
      value_type,
      reference>::type iterator_category;  ///< The iterator category
#else
  typedef std::random_access_iterator_tag
      iterator_category;  ///< The iterator category
#endif  // THRUST_VERSION

  WhereOutputIterator(int64* ptr, const Eigen::DenseIndex max_row)
      : ptr_(ptr), max_row_(max_row) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int64& operator[](int n) const {
    // If the selection mechanism finds too many true values (because
    // the input tensor changed between allocation of output and now),
    // we may accidentally try to write past the allowable memory.  If
    // valid is false, then we don't do this.  Instead, we'll read off
    // the number of items found in Flagged()'s d_num_selected_out at
    // the end and confirm that it matches the number of rows of output.
    const bool valid = FastBoundsCheck(n, max_row_);
    return *(ptr_ + (valid ? (NDIM * n) : 0));
  }

 private:
  int64* ptr_;
  const Eigen::DenseIndex max_row_;
};

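// Computes row-major strides for the input shape, i.e. the number of elements
// spanned by one step along each dimension.  For example, a tensor of shape
// {2, 3, 4} yields strides {12, 4, 1}.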
template <typename TIndex, typename T, int NDIM>
Eigen::array<TIndex, NDIM> CalculateStrides(
    typename TTypes<T, NDIM>::ConstTensor input) {
  const Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions();
  Eigen::array<TIndex, NDIM> strides;
  EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                       static_cast<int>(Eigen::RowMajor)),
                      INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
  strides[NDIM - 1] = 1;
  for (int i = NDIM - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  return strides;
}

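// GPU implementation of Where: DeviceSelect::Flagged (via
// CubDeviceSelectFlaggedCounter) writes the flat index of every nonzero input
// element into column 0 of the corresponding output row and the selected
// count into found_true_device, after which PropagateWhereIndicesKernel
// expands each flat index into NDIM coordinates in place.  As with NumTrue,
// each gpuprim call is issued twice: once to size the temporary storage and
// once to do the actual work.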
template <int NDIM, typename T, typename TIndex>
struct Where<GPUDevice, NDIM, T, TIndex> {
  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const GPUDevice& d,
      typename TTypes<T, NDIM>::ConstTensor input,
      typename TTypes<int64>::Matrix output, TIndex* found_true_host) {
    if (output.dimension(0) == 0) {
      // Nothing to do.
      return Status::OK();
    }

    const auto& cu_stream = GetGpuStream(ctx);

    std::size_t temp_storage_bytes = 0;

    Tensor found_true_t;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<TIndex>::v(),
                                          TensorShape({}), &found_true_t));
    TIndex* found_true_device = found_true_t.scalar<TIndex>().data();

    WhereOutputIterator<NDIM> output_iterator(
        output.data(),
        /* max_row */ output.dimension(0));

    typedef typename std::decay<T>::type DT;
    CubDeviceSelectFlaggedCounter<
        T, TIndex, decltype(output_iterator) /*OutputIterator*/,
        std::is_convertible<DT, bool>::value /*IsConvertibleToBool*/>
        counter;
    auto first_success = counter(/*temp_storage*/ nullptr, temp_storage_bytes,
                                 /*d_flags*/ input.data(),
                                 /*d_out*/ output_iterator,
                                 /*d_num_selected_out*/ found_true_device,
                                 /*num_items*/ input.size(),
                                 /*stream*/ cu_stream);
    if (first_success != gpuSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch gpuprim::DeviceSelect::Flagged to "
          "calculate temp_storage_bytes, status: ",
          GpuGetErrorString(first_success));
    }

    Tensor temp_storage;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
        &temp_storage));

    auto second_success = counter(
        /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes,
        /*d_flags*/ input.data(),
        /*d_out*/ output_iterator,
        /*d_num_selected_out*/ found_true_device,
        /*num_items*/ input.size(),
        /*stream*/ cu_stream);

    if (second_success != gpuSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch gpuprim::DeviceSelect::Flagged to copy "
          "indices out, status: ",
          GpuGetErrorString(second_success));
    }

    // TODO(ebrevdo): Find a way to synchronously copy back data from
    // found_true_device to *found_true_host.

    const Eigen::array<TIndex, NDIM> strides =
        CalculateStrides<TIndex, T, NDIM>(input);
    const TIndex output_rows = output.dimension(0);
    GpuLaunchConfig config = GetGpuLaunchConfig(output_rows, d);
    TF_CHECK_OK(GpuLaunchKernel(PropagateWhereIndicesKernel<NDIM, TIndex>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), output_rows, strides,
                                output.data()));

    return Status::OK();
  }
};

#define DECLARE_GPU_SPEC_INDEX(Dims, T, TIndex) \
  template struct Where<GPUDevice, Dims, T, TIndex>

#define DECLARE_GPU_SPEC(T)                           \
  DECLARE_GPU_SPEC_INDEX(GPU_PROVIDED_DIM, T, int32); \
  DECLARE_GPU_SPEC_INDEX(GPU_PROVIDED_DIM, T, int64)

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_INDEX

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_