/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/in_topk_op.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Compare each prediction in 'predictions' with the target prediction for its
// batch, and write the result to 'mask':
//  -1: If the target class is out of range, or if either the prediction or
//      the target prediction is non-finite, so the two cannot be compared.
//   0: If the prediction is smaller than or equal to the target prediction
//      for the batch.
//   1: If the prediction is larger than the target prediction for the batch.
template <typename T, typename TargetT>
__global__ void ComputePredictionMaskKernel(
    const T* __restrict__ predictions,    // dims: [ num_targets x num_classes ]
    const TargetT* __restrict__ targets,  // dims: [ num_targets ]
    int64* __restrict__ mask,             // dims: [ num_targets x num_classes ]
    int num_targets, int num_classes) {
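  // Grid-stride loop over the flattened [num_targets x num_classes] matrix;
  // a single thread may be responsible for several elements.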
  GPU_1D_KERNEL_LOOP(i, num_targets * num_classes) {
    const int batch_index = i / num_classes;
    TargetT target_idx = ldg(targets + batch_index);

    if (!FastBoundsCheck(target_idx, num_classes)) {
      mask[i] = -1;
      // Skip only this element: returning here would leave any elements
      // assigned to this thread at later strides unwritten.
      continue;
    }

    T prediction = ldg(predictions + i);
    T target_prediction =
        ldg(predictions + batch_index * num_classes + target_idx);

    if (!Eigen::numext::isfinite(prediction) ||
        !Eigen::numext::isfinite(target_prediction)) {
      mask[i] = -1;
    } else {
      mask[i] = prediction > target_prediction ? 1 : 0;
    }
  }
}

// Reduce each row of the prediction mask to the number of predictions larger
// than the target prediction, or to '-1' if the target class is invalid or
// any prediction in the batch is non-finite.
struct MaskSum {
  __host__ __device__ int64 operator()(const int64& a, const int64& b) const {
    if (a < 0 || b < 0)
      return -1;
    else
      return a + b;
  }
};

namespace reduction_op_helper {
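// Provide the identity element for reductions with MaskSum: combining 0 with
// any mask value leaves that value unchanged.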
template <>
struct IdentityValue<int64, MaskSum> {
  int64 operator()() { return 0; }
};

}  // namespace reduction_op_helper

template <typename T, typename TargetT>
struct InTopKFunctor<GPUDevice, T, TargetT> {
  template <int ndims>
  using Dims = Eigen::DSizes<Eigen::Index, ndims>;

  void operator()(OpKernelContext* context,
                  typename TTypes<T, 2>::ConstTensor predictions,
                  typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
                  typename TTypes<bool>::Vec output) {
    const Eigen::Index num_targets = predictions.dimension(0);
    const Eigen::Index num_classes = predictions.dimension(1);

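    // The mask kernel receives the element count as a 32-bit int, so the
    // flattened [num_targets, num_classes] matrix must stay below INT_MAX.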
    OP_REQUIRES(
        context, num_targets * num_classes < std::numeric_limits<int>::max(),
        errors::InvalidArgument(
            "Number of targets * number of classes must be less than INT_MAX"));

    if (num_targets == 0 || num_classes == 0) {
      // Result is empty, so shortcut the rest of the function to avoid
      // launching kernels with empty input.
      return;
    }

    // Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
    Tensor predictions_mask;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DT_INT64,
                                        TensorShape({num_targets, num_classes}),
                                        &predictions_mask));

    // Number of predictions for each target that are larger than the target
    // prediction (or -1 if this number cannot be computed because some
    // predictions are non-finite or the target class is out of range).
    Tensor num_larger_prediction;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({num_targets}),
                                          &num_larger_prediction));

    const auto& d = context->eigen_device<GPUDevice>();

    // Compute a mask for all predictions.
    GpuLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
    OP_REQUIRES_OK(
        context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
                                 config.block_count, config.thread_per_block, 0,
                                 d.stream(), predictions.data(), targets.data(),
                                 predictions_mask.flat<int64_t>().data(),
                                 num_targets, num_classes));

    // Reduce prediction masks to the number of predictions larger than the
    // target prediction, or to -1 if an answer cannot be computed.
    {
      auto in = predictions_mask.matrix<int64_t>();
      auto out = num_larger_prediction.flat<int64_t>();

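      // Reduce along axis 1 (the class dimension) of the rank-2 mask,
      // producing one count per target.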
      ReduceImpl<int64, MaskSum, int64*, int64*, Dims<1>>(
          context, (int64*)out.data(), (int64*)in.data(), in.rank(),
          in.dimension(0), in.rank() >= 2 ? in.dimension(1) : 1,
          in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), Dims<1>(1),
          MaskSum());
    }

    // Compute whether the target prediction is among the top K predictions.
    auto cnt = num_larger_prediction.flat<int64_t>();

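    // 'k' arrives either as a scalar tensor (int32 or int64) or as a plain
    // value; the 'cnt >= 0' check rejects targets whose mask reduced to -1.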
    if (k.k_tensor != nullptr) {
      if (k.k_tensor->dtype() == DT_INT32) {
        output.device(d) =
            (cnt >= cnt.constant(0)) &&
            (cnt < k.k_tensor->flat<int32>().template cast<int64_t>().broadcast(
                       Dims<1>(num_targets)));
      } else {
        output.device(d) =
            (cnt >= cnt.constant(0)) &&
            (cnt < k.k_tensor->flat<int64_t>().broadcast(Dims<1>(num_targets)));
      }
    } else {
      output.device(d) =
          (cnt >= cnt.constant(0)) && (cnt < targets.constant(k.k_value));
    }
  }
};

}  // namespace functor

// Definition of the GPU implementations declared in in_topk_op.cc.
#define DEFINE_GPU_KERNELS(T, TARGET_T) \
  template struct functor::InTopKFunctor<GPUDevice, T, TARGET_T>;

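// On GPU, InTopK is instantiated for float predictions with either int32 or
// int64 target indices.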
DEFINE_GPU_KERNELS(float, int32);
DEFINE_GPU_KERNELS(float, int64);

#undef DEFINE_GPU_KERNELS

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM