/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/argmax_op.h"

#include <memory>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tout, typename ArgFunctor>
class ArgOp : public OpKernel {
 public:
  explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& dimension = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()),
                errors::InvalidArgument(
                    "dim must be a scalar, but received tensor of shape: ",
                    dimension.shape().DebugString()));

    const int32 dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
    const int input_dims = input.dims();

    int axis = dim < 0 ? dim + input_dims : dim;

    OP_REQUIRES(context, FastBoundsCheck(axis, input_dims),
                errors::InvalidArgument("Expected dimension in the range [",
                                        -input_dims, ", ", input_dims,
                                        "), but got ", dim));
    OP_REQUIRES(
        context, input.dim_size(axis) > 0,
        errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
                                input.shape().DebugString()));
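
    // The output shape is the input shape with the reduction axis removed:
    // e.g. reducing a [2, 3, 4] input along axis 1 produces a [2, 4] output
    // holding indices into dimension 1.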
    TensorShape output_shape;
    const TensorShape& input_shape = input.shape();
    for (int d = 0; d < input_dims - 1; ++d) {
      output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));

    if (output_shape.num_elements() == 0) {
      return;
    }

#define HANDLE_DIM(NDIM)                                        \
  case NDIM:                                                    \
    ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(),   \
                             input.tensor<T, NDIM>(), axis,     \
                             output->tensor<Tout, NDIM - 1>()); \
    break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "ArgOp : Unhandled input dimensions: ", input_dims));
    }
  }
#undef HANDLE_DIM

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
};

template <typename Device, typename T, typename Tout>
class ArgMaxOp
    : public ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> > {
 public:
  explicit ArgMaxOp(OpKernelConstruction* context)
      : ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> >(context) {}
};

template <typename Device, typename T, typename Tout>
class ArgMinOp
    : public ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> > {
 public:
  explicit ArgMinOp(OpKernelConstruction* context)
      : ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> >(context) {}
};
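
// Registration of the CPU implementations. Each REGISTER_ARGMAX expansion
// registers four CPU kernels for the given input type: ArgMax and ArgMin,
// each with int32 and int64 output_type. The `dimension` input is pinned to
// host memory because Compute() reads its value directly.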
#define REGISTER_ARGMAX(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("output_type") \
                              .HostMemory("dimension"),             \
                          ArgMaxOp<CPUDevice, type, int64>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("output_type") \
                              .HostMemory("dimension"),             \
                          ArgMinOp<CPUDevice, type, int64>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("output_type") \
                              .HostMemory("dimension"),             \
                          ArgMaxOp<CPUDevice, type, int32>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("output_type") \
                              .HostMemory("dimension"),             \
                          ArgMinOp<CPUDevice, type, int32>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);

#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {

#define DECLARE_GPU_SPEC(T, Tout, Dims)                                       \
  template <>                                                                 \
  void ArgMax<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output); \
  template <>                                                                 \
  void ArgMin<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output);

#define DECLARE_GPU_SPECS(T)     \
  DECLARE_GPU_SPEC(T, int64, 1); \
  DECLARE_GPU_SPEC(T, int64, 2); \
  DECLARE_GPU_SPEC(T, int64, 3); \
  DECLARE_GPU_SPEC(T, int64, 4); \
  DECLARE_GPU_SPEC(T, int64, 5); \
  DECLARE_GPU_SPEC(T, int32, 1); \
  DECLARE_GPU_SPEC(T, int32, 2); \
  DECLARE_GPU_SPEC(T, int32, 3); \
  DECLARE_GPU_SPEC(T, int32, 4); \
  DECLARE_GPU_SPEC(T, int32, 5);

#define DECLARE_GPU_CLASS(T)                          \
  extern template struct ArgMax<GPUDevice, T, int64>; \
  extern template struct ArgMin<GPUDevice, T, int64>; \
  extern template struct ArgMax<GPUDevice, T, int32>; \
  extern template struct ArgMin<GPUDevice, T, int32>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_CLASS

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_ARGMAX_GPU(type)                                   \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("output_type") \
                              .TypeConstraint<int32>("Tidx")        \
                              .HostMemory("dimension"),             \
                          ArgMaxOp<GPUDevice, type, int64>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("output_type") \
                              .TypeConstraint<int32>("Tidx")        \
                              .HostMemory("dimension"),             \
                          ArgMinOp<GPUDevice, type, int64>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("output_type") \
                              .TypeConstraint<int32>("Tidx")        \
                              .HostMemory("dimension"),             \
                          ArgMaxOp<GPUDevice, type, int32>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("output_type") \
                              .TypeConstraint<int32>("Tidx")        \
                              .HostMemory("dimension"),             \
                          ArgMinOp<GPUDevice, type, int32>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);

#undef REGISTER_ARGMAX_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow