/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/where_op.h"

#include <memory>
#include <numeric>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
template <typename T>
int64 CountAccumulator(const T* begin, const T* end) {
  return std::accumulate(begin, end, 0LL, [](int64 accum, const T& val) {
    return accum + (val != T(0));
  });
}

template <>
int64 CountAccumulator<bool>(const bool* begin, const bool* end) {
  return std::accumulate(begin, end, 0LL);
}

}  // namespace

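// NumTrue<CPUDevice>: counts the nonzero (or true) entries of the flattened
// input on the host via CountAccumulator above; for bool the accumulator just
// sums the values, since each true element contributes exactly 1.  The count
// produced here later determines how many rows the Where output gets.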
template <typename T>
struct NumTrue<CPUDevice, T, int64> {
  static Status Compute(OpKernelContext* ctx, const CPUDevice& d,
                        typename TTypes<T>::ConstFlat input,
                        TTypes<int64>::Scalar num_true) {
    num_true() = CountAccumulator<T>(input.data(), input.data() + input.size());
    return Status::OK();
  }
};

template <int DIMS, typename T, typename TIndex>
struct Where<CPUDevice, DIMS, T, TIndex> {
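  // Writes the DIMS-dimensional coordinate of flat position `index` into row
  // `true_n` of `output` using the precomputed row-major strides.  Example:
  // for a 2x3 input the strides are {3, 1}, so flat index 4 becomes the
  // coordinate (1, 1): 4 / 3 = 1 with remainder 1, then 1 / 1 = 1.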
  EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
      typename TTypes<int64>::Matrix output,
      const typename Eigen::DSizes<TIndex, DIMS>& strides, TIndex true_n,
      TIndex index) {
    for (int i = 0; i < DIMS; ++i) {
      output(true_n, i) = index / strides[i];
      index -= output(true_n, i) * strides[i];
    }
  }

  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const CPUDevice& d,
      typename TTypes<T, DIMS>::ConstTensor input,
      typename TTypes<int64>::Matrix output, TIndex* found_true) {
    Eigen::DSizes<Eigen::DenseIndex, DIMS> dims = input.dimensions();
    Eigen::DSizes<TIndex, DIMS> strides;

    EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                         static_cast<int>(Eigen::RowMajor)),
                        INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);

    strides[DIMS - 1] = 1;
    for (int i = DIMS - 2; i >= 0; --i) {
      strides[i] = strides[i + 1] * dims[i + 1];
    }

    Eigen::DenseIndex output_size = output.dimension(0);
    for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
      if (input.data()[n] != T(0)) {
        if (FastBoundsCheck(*found_true, output_size)) {
          WriteIndexRowMajor(output, strides, *found_true, n);
        }
        ++*found_true;
      }
    }
    return Status::OK();
  }
};

}  // namespace functor

template <typename T>
class WhereCPUOp : public OpKernel {
 public:
  explicit WhereCPUOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    OP_REQUIRES(
        context, input.dtype() != DT_HALF,
        errors::Unimplemented("No WhereOp available for float16/half type on "
                              "CPU; dying in CPU WhereOp to avoid silently "
                              "creating costly copies from device."));

    const int input_dims = input.dims();

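    // Two-pass algorithm: first count the nonzero elements so the output can
    // be allocated with shape [num_true, input_dims]; then walk the input
    // again and write the row-major coordinates of every nonzero element.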
    Tensor num_true;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(context, context->allocate_temp(DT_INT64, TensorShape({}),
                                                   &num_true, attr));
    auto num_true_t = num_true.scalar<int64>();

    Status s = functor::NumTrue<CPUDevice, T, int64>::Compute(
        context, context->eigen_device<CPUDevice>(), input.flat<T>(),
        num_true_t);
    OP_REQUIRES_OK(context, s);
    TensorShape output_shape({num_true_t(), input_dims});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // TODO(ebrevdo): Replace single-threaded copy with a
    // multithreaded block copy by getting block counts above instead
    // of a global NumTrue, then having each block filled in in
    // separate threads below.
    int64 found_true = 0;

#define HANDLE_DIM(NDIM)                                                      \
  case NDIM: {                                                                \
    Status s = functor::Where<CPUDevice, NDIM, T, int64>::Compute(            \
        context, context->eigen_device<CPUDevice>(), input.tensor<T, NDIM>(), \
        output->matrix<int64>(), &found_true);                                \
    OP_REQUIRES_OK(context, s);                                               \
  } break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "WhereOp : Unhandled input dimensions: ", input_dims));
    }
#undef HANDLE_DIM

    OP_REQUIRES(
        context, found_true == num_true_t(),
        errors::InvalidArgument(
            "WhereOp: Race condition between counting the number of true "
            "elements and writing them.  When counting, saw ",
            num_true_t(), " elements; but when writing their indices, saw ",
            found_true, " elements."));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereCPUOp);
};

#define REGISTER_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(   \
      Name("Where").Device(DEVICE_CPU).TypeConstraint<T>("T"), WhereCPUOp<T>);

TF_CALL_NUMBER_TYPES(REGISTER_WHERE_OP);
TF_CALL_bool(REGISTER_WHERE_OP);

#undef REGISTER_WHERE_OP

#if GOOGLE_CUDA

namespace functor {

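// Forward declarations of the GPU specializations of NumTrue and Where.  The
// definitions are instantiated in a separately compiled CUDA translation
// unit; the extern template declarations below keep those instantiations out
// of this file.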
#define DECLARE_GPU_NUMTRUE(T, Tindex)                                      \
  template <>                                                               \
  Status NumTrue<GPUDevice, T, Tindex>::Compute(                            \
      OpKernelContext* ctx, const GPUDevice& d, TTypes<T>::ConstFlat input, \
      TTypes<Tindex>::Scalar num_true);                                     \
  extern template struct NumTrue<GPUDevice, T, Tindex>

#define DECLARE_GPU_NUMTRUE_TYPE(T) \
  DECLARE_GPU_NUMTRUE(T, int32);    \
  DECLARE_GPU_NUMTRUE(T, int64);

TF_CALL_NUMBER_TYPES(DECLARE_GPU_NUMTRUE_TYPE);
TF_CALL_bool(DECLARE_GPU_NUMTRUE_TYPE);

#undef DECLARE_GPU_NUMTRUE_TYPE
#undef DECLARE_GPU_NUMTRUE

#define DECLARE_GPU_WHERE_INDEX(Dims, T, Tindex)                  \
  template <>                                                     \
  Status Where<GPUDevice, Dims, T, Tindex>::Compute(              \
      OpKernelContext* ctx, const GPUDevice& d,                   \
      typename TTypes<T, Dims>::ConstTensor input,                \
      typename TTypes<int64>::Matrix output, Tindex* found_true); \
  extern template struct Where<GPUDevice, Dims, T, Tindex>;
#define DECLARE_GPU_WHERE(Dims, T)         \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int32); \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int64);

#define DECLARE_GPU_WHERE_TYPES(T) \
  DECLARE_GPU_WHERE(1, T);         \
  DECLARE_GPU_WHERE(2, T);         \
  DECLARE_GPU_WHERE(3, T);         \
  DECLARE_GPU_WHERE(4, T);         \
  DECLARE_GPU_WHERE(5, T);

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_WHERE_TYPES);

#undef DECLARE_GPU_WHERE_TYPES
#undef DECLARE_GPU_WHERE
#undef DECLARE_GPU_WHERE_INDEX

}  // namespace functor

template <typename T>
class WhereGPUOp : public AsyncOpKernel {
 public:
  explicit WhereGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

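    // Dispatch on the index type: 32-bit indexing is used whenever every flat
    // index fits in an int32 (narrower indices are generally cheaper on the
    // GPU), and 64-bit indexing otherwise.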
    if (input.NumElements() < std::numeric_limits<int32>::max()) {
      ComputeAsyncType<int32>(input, input_dims, context, done);
    } else {
      ComputeAsyncType<int64>(input, input_dims, context, done);
    }
  }

  template <typename Tindex>
  void ComputeAsyncType(const Tensor& input, const int input_dims,
                        OpKernelContext* context, DoneCallback done) {
    // Step 0: alloc nnz
    // Step 1: call nnz kernel
    // Step 2: copy nnz to host
    // Step 3: call create_output
    // Step 4: call where kernel
    Tensor num_true;
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(DataTypeToEnum<Tindex>::v(),
                                                TensorShape({}), &num_true),
                         done);

    auto num_true_t = num_true.scalar<Tindex>();

    se::DeviceMemoryBase num_true_ptr(static_cast<void*>(num_true_t.data()));
    // Push kernel to stream to get number of true elements.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    Status s = functor::NumTrue<GPUDevice, T, Tindex>::Compute(
        context, d, input.flat<T>(), num_true_t);
    OP_REQUIRES_OK_ASYNC(context, s, done);

    // Copy num_true to host.
    ScratchSpace<Tindex> num_true_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(num_true_host.mutable_data(), num_true_ptr,
                         sizeof(Tindex))
            .ok(),
        errors::Internal("WhereOp: failed to copy num_true from device"), done);

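    // The output shape depends on num_true, which is only known on the host
    // once the memcpy above has completed.  Allocating the output and
    // launching the Where kernel are therefore deferred to a callback that
    // the EventMgr runs after the work already enqueued on the stream has
    // finished (see ThenExecute below).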
    auto create_and_check_output = [context, &d, &input, input_dims,
                                    num_true_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Tindex num_true = *num_true_host.data();

      // TODO(ebrevdo): Properly copy back found_true value to CPU for
      // validation checking.  Currently Where<GPUDevice>::Compute()
      // does not perform this copy back to CPU.
      Tindex found_true = -1;

      // Step 1: Allocate the output and perform the selection/copy.
      Tensor* output;
      OP_REQUIRES_OK_ASYNC(context,
                           context->allocate_output(
                               0, TensorShape({num_true, input_dims}), &output),
                           done);

#define HANDLE_DIM(NDIM)                                              \
  case NDIM: {                                                        \
    Status s = functor::Where<GPUDevice, NDIM, T, Tindex>::Compute(   \
        context, d, input.tensor<T, NDIM>(), output->matrix<int64>(), \
        &found_true);                                                 \
    OP_REQUIRES_OK_ASYNC(context, s, done);                           \
  } break;

      switch (input_dims) {
        HANDLE_DIM(1);
        HANDLE_DIM(2);
        HANDLE_DIM(3);
        HANDLE_DIM(4);
        HANDLE_DIM(5);

        default:
          OP_REQUIRES_ASYNC(
              context, false,
              errors::InvalidArgument("WhereOp: Unhandled input dimensions: ",
                                      input_dims),
              done);
      }
#undef HANDLE_DIM

      // TODO(ebrevdo): Fix the copy back to host.

      // OP_REQUIRES_ASYNC(
      //     context, found_true == num_true,
      //     errors::InvalidArgument(
      //         "WhereOp: Race condition between counting the number of true "
      //         "elements and writing them.  When counting, saw ",
      //         num_true, " elements; but when writing their indices, saw ",
      //         found_true, " elements."),
      //     done);

      done();
    };
    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereGPUOp);
};

#define REGISTER_GPU_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(       \
      Name("Where").Device(DEVICE_GPU).TypeConstraint<T>("T"), WhereGPUOp<T>);

TF_CALL_WHERE_GPU_TYPES(REGISTER_GPU_WHERE_OP);
REGISTER_KERNEL_BUILDER(Name("Where")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("index"),
                        WhereCPUOp<int32>);
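// For int32 inputs the GPU device registers the CPU kernel with both the
// input and the index output pinned to host memory; int32 tensors are
// commonly kept in host memory for GPU devices in TensorFlow, so running
// Where on the host here is presumably meant to avoid extra device copies.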

#undef REGISTER_GPU_WHERE_OP

#endif  // GOOGLE_CUDA

}  // namespace tensorflow