/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/where_op.h"

#include <memory>
#include <numeric>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
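// Returns the number of elements in [begin, end) that compare unequal to
// T(0), accumulated as an int64.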
template <typename T>
int64 CountAccumulator(const T* begin, const T* end) {
  return std::accumulate(begin, end, 0LL, [](int64 accum, const T& val) {
    return accum + (val != T(0));
  });
}

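// For bool, summing the values directly counts the true entries, since
// true converts to 1 and false to 0.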
template <>
int64 CountAccumulator<bool>(const bool* begin, const bool* end) {
  return std::accumulate(begin, end, 0LL);
}

}  // namespace

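// CPU specialization of NumTrue: counts the nonzero elements of the
// flattened input and writes the result into the num_true scalar.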
template <typename T>
struct NumTrue<CPUDevice, T, int64> {
  static Status Compute(OpKernelContext* ctx, const CPUDevice& d,
                        typename TTypes<T>::ConstFlat input,
                        TTypes<int64>::Scalar num_true) {
    num_true() = CountAccumulator<T>(input.data(), input.data() + input.size());
    return Status::OK();
  }
};

template <int DIMS, typename T, typename TIndex>
struct Where<CPUDevice, DIMS, T, TIndex> {
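  // Decomposes the flat row-major index `index` of the true_n-th true
  // element into DIMS coordinates and writes them into row true_n of the
  // output matrix.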
  EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
      typename TTypes<int64>::Matrix output,
      const typename Eigen::DSizes<TIndex, DIMS>& strides, TIndex true_n,
      TIndex index) {
    for (int i = 0; i < DIMS; ++i) {
      output(true_n, i) = index / strides[i];
      index -= output(true_n, i) * strides[i];
    }
  }

  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const CPUDevice& d,
      typename TTypes<T, DIMS>::ConstTensor input,
      typename TTypes<int64>::Matrix output, TIndex* found_true) {
    Eigen::DSizes<Eigen::DenseIndex, DIMS> dims = input.dimensions();
    Eigen::DSizes<TIndex, DIMS> strides;

    EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                         static_cast<int>(Eigen::RowMajor)),
                        INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);

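    // Compute row-major strides: strides[i] is the number of flat elements
    // spanned by a unit step along dimension i.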
    strides[DIMS - 1] = 1;
    for (int i = DIMS - 2; i >= 0; --i) {
      strides[i] = strides[i + 1] * dims[i + 1];
    }

    Eigen::DenseIndex output_size = output.dimension(0);
    for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
      if (input.data()[n] != T(0)) {
        if (FastBoundsCheck(*found_true, output_size)) {
          WriteIndexRowMajor(output, strides, *found_true, n);
        }
        ++*found_true;
      }
    }
    return Status::OK();
  }
};

}  // namespace functor

template <typename T>
class WhereCPUOp : public OpKernel {
 public:
  explicit WhereCPUOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    OP_REQUIRES(
        context, input.dtype() != DT_HALF,
        errors::Unimplemented("No WhereOp available for float16/half type on "
                              "CPU; dying in CPU WhereOp to avoid silently "
                              "creating costly copies from device."));

    const int input_dims = input.dims();

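    // First pass: count the number of true (nonzero) elements so the output
    // can be allocated with the exact shape [num_true, input_dims].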
    Tensor num_true;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(context, context->allocate_temp(DT_INT64, TensorShape({}),
                                                   &num_true, attr));
    auto num_true_t = num_true.scalar<int64>();

    Status s = functor::NumTrue<CPUDevice, T, int64>::Compute(
        context, context->eigen_device<CPUDevice>(), input.flat<T>(),
        num_true_t);
    OP_REQUIRES_OK(context, s);
    TensorShape output_shape({num_true_t(), input_dims});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // TODO(ebrevdo): Replace single-threaded copy with a
    // multithreaded block copy by getting block counts above instead
    // of a global NumTrue, then having each block filled in in
    // separate threads below.
    int64 found_true = 0;

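    // Second pass: dispatch on the input rank and write the coordinates of
    // each true element into the output matrix.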
#define HANDLE_DIM(NDIM)                                                      \
  case NDIM: {                                                                \
    Status s = functor::Where<CPUDevice, NDIM, T, int64>::Compute(            \
        context, context->eigen_device<CPUDevice>(), input.tensor<T, NDIM>(), \
        output->matrix<int64>(), &found_true);                                \
    OP_REQUIRES_OK(context, s);                                               \
  } break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "WhereOp : Unhandled input dimensions: ", input_dims));
    }
#undef HANDLE_DIM

    OP_REQUIRES(
        context, found_true == num_true_t(),
        errors::InvalidArgument(
            "WhereOp: Race condition between counting the number of true "
            "elements and writing them. When counting, saw ",
            num_true_t(), " elements; but when writing their indices, saw ",
            found_true, " elements."));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereCPUOp);
};

#define REGISTER_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(   \
      Name("Where").Device(DEVICE_CPU).TypeConstraint<T>("T"), WhereCPUOp<T>);

TF_CALL_NUMBER_TYPES(REGISTER_WHERE_OP);
TF_CALL_bool(REGISTER_WHERE_OP);

#undef REGISTER_WHERE_OP

#if GOOGLE_CUDA

namespace functor {

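// Forward declarations of the GPU specializations of NumTrue and Where; the
// definitions are compiled separately in the CUDA build.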
#define DECLARE_GPU_NUMTRUE(T, Tindex)                                      \
  template <>                                                               \
  Status NumTrue<GPUDevice, T, Tindex>::Compute(                            \
      OpKernelContext* ctx, const GPUDevice& d, TTypes<T>::ConstFlat input, \
      TTypes<Tindex>::Scalar num_true);                                     \
  extern template struct NumTrue<GPUDevice, T, Tindex>

#define DECLARE_GPU_NUMTRUE_TYPE(T) \
  DECLARE_GPU_NUMTRUE(T, int32);    \
  DECLARE_GPU_NUMTRUE(T, int64);

TF_CALL_NUMBER_TYPES(DECLARE_GPU_NUMTRUE_TYPE);
TF_CALL_bool(DECLARE_GPU_NUMTRUE_TYPE);

#undef DECLARE_GPU_NUMTRUE_TYPE
#undef DECLARE_GPU_NUMTRUE

#define DECLARE_GPU_WHERE_INDEX(Dims, T, Tindex)                  \
  template <>                                                     \
  Status Where<GPUDevice, Dims, T, Tindex>::Compute(              \
      OpKernelContext* ctx, const GPUDevice& d,                   \
      typename TTypes<T, Dims>::ConstTensor input,                \
      typename TTypes<int64>::Matrix output, Tindex* found_true); \
  extern template struct Where<GPUDevice, Dims, T, Tindex>;
#define DECLARE_GPU_WHERE(Dims, T)         \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int32); \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int64);

#define DECLARE_GPU_WHERE_TYPES(T) \
  DECLARE_GPU_WHERE(1, T);         \
  DECLARE_GPU_WHERE(2, T);         \
  DECLARE_GPU_WHERE(3, T);         \
  DECLARE_GPU_WHERE(4, T);         \
  DECLARE_GPU_WHERE(5, T);

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_WHERE_TYPES);

#undef DECLARE_GPU_WHERE_TYPES
#undef DECLARE_GPU_WHERE
#undef DECLARE_GPU_WHERE_INDEX

}  // namespace functor

template <typename T>
class WhereGPUOp : public AsyncOpKernel {
 public:
  explicit WhereGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

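    // Use 32-bit indices when the flattened input fits in int32, and fall
    // back to 64-bit indices otherwise.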
    if (input.NumElements() < std::numeric_limits<int32>::max()) {
      ComputeAsyncType<int32>(input, input_dims, context, done);
    } else {
      ComputeAsyncType<int64>(input, input_dims, context, done);
    }
  }

  template <typename Tindex>
  void ComputeAsyncType(const Tensor& input, const int input_dims,
                        OpKernelContext* context, DoneCallback done) {
    // Step 0: alloc nnz
    // Step 1: call nnz kernel
    // Step 2: copy nnz to host
    // Step 3: call create_output
    // Step 4: call where kernel
    Tensor num_true;
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(DataTypeToEnum<Tindex>::v(),
                                                TensorShape({}), &num_true),
                         done);

    auto num_true_t = num_true.scalar<Tindex>();

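    // Wrap the device-resident scalar so its contents can later be copied
    // back to the host via the stream.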
    se::DeviceMemoryBase num_true_ptr(static_cast<void*>(num_true_t.data()));
    // Push kernel to stream to get number of true elements.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    Status s = functor::NumTrue<GPUDevice, T, Tindex>::Compute(
        context, d, input.flat<T>(), num_true_t);
    OP_REQUIRES_OK_ASYNC(context, s, done);

    // Copy num_true to host;
    ScratchSpace<Tindex> num_true_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(num_true_host.mutable_data(), num_true_ptr,
                         sizeof(Tindex))
            .ok(),
        errors::Internal("WhereOp: failed to copy num_true from device"), done);

    auto create_and_check_output = [context, &d, &input, input_dims,
                                    num_true_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Tindex num_true = *num_true_host.data();

      // TODO(ebrevdo): Properly copy back found_true value to CPU for
      // validation checking. Currently Where<GPUDevice>::Compute()
      // does not perform this copy back to CPU.
      Tindex found_true = -1;

      // Step 1: Allocate the output and perform the selection/copy.
      Tensor* output;
      OP_REQUIRES_OK_ASYNC(context,
                           context->allocate_output(
                               0, TensorShape({num_true, input_dims}), &output),
                           done);

#define HANDLE_DIM(NDIM)                                              \
  case NDIM: {                                                        \
    Status s = functor::Where<GPUDevice, NDIM, T, Tindex>::Compute(   \
        context, d, input.tensor<T, NDIM>(), output->matrix<int64>(), \
        &found_true);                                                 \
    OP_REQUIRES_OK_ASYNC(context, s, done);                           \
  } break;

      switch (input_dims) {
        HANDLE_DIM(1);
        HANDLE_DIM(2);
        HANDLE_DIM(3);
        HANDLE_DIM(4);
        HANDLE_DIM(5);

        default:
          OP_REQUIRES_ASYNC(
              context, false,
              errors::InvalidArgument("WhereOp: Unhandled input dimensions: ",
                                      input_dims),
              done);
      }
#undef HANDLE_DIM

      // TODO(ebrevdo): Fix the copy back to host.

      // OP_REQUIRES_ASYNC(
      //     context, found_true == num_true,
      //     errors::InvalidArgument(
      //         "WhereOp: Race condition between counting the number of true "
      //         "elements and writing them. When counting, saw ",
      //         num_true, " elements; but when writing their indices, saw ",
      //         found_true, " elements."),
      //     done);

      done();
    };
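    // Run the callback once the stream has finished the memcpy above, so
    // that num_true_host holds a valid count when the output is allocated.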
    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereGPUOp);
};

#define REGISTER_GPU_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(       \
      Name("Where").Device(DEVICE_GPU).TypeConstraint<T>("T"), WhereGPUOp<T>);

TF_CALL_WHERE_GPU_TYPES(REGISTER_GPU_WHERE_OP);
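// For int32 inputs, both the input and the output indices are kept in host
// memory and handled by the CPU kernel.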
REGISTER_KERNEL_BUILDER(Name("Where")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("index"),
                        WhereCPUOp<int32>);

#undef REGISTER_GPU_WHERE_OP

#endif  // GOOGLE_CUDA

}  // namespace tensorflow