• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #if GOOGLE_CUDA
21 #define EIGEN_USE_GPU
22 #endif  // GOOGLE_CUDA
23 
24 #include "tensorflow/core/kernels/argmax_op.h"
25 
26 #include <memory>
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/bounds_check.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/register_types.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/tensor_shape.h"
33 #include "tensorflow/core/framework/tensor_types.h"
34 #include "tensorflow/core/framework/types.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/macros.h"
37 
38 namespace tensorflow {
39 
40 typedef Eigen::ThreadPoolDevice CPUDevice;
41 typedef Eigen::GpuDevice GPUDevice;
42 
43 template <typename Device, typename T, typename Tout, typename ArgFunctor>
44 class ArgOp : public OpKernel {
45  public:
ArgOp(OpKernelConstruction * context)46   explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {}
47 
Compute(OpKernelContext * context)48   void Compute(OpKernelContext* context) override {
49     const Tensor& input = context->input(0);
50     const Tensor& dimension = context->input(1);
51 
52     OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()),
53                 errors::InvalidArgument(
54                     "dim must be a scalar, but received tensor of shape: ",
55                     dimension.shape().DebugString()));
56 
57     const int32 dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
58     const int input_dims = input.dims();
59 
60     int axis = dim < 0 ? dim + input_dims : dim;
61 
62     OP_REQUIRES(context, FastBoundsCheck(axis, input_dims),
63                 errors::InvalidArgument("Expected dimension in the range [",
64                                         -input_dims, ", ", input_dims,
65                                         "), but got ", dim));
66     OP_REQUIRES(
67         context, input.dim_size(axis) > 0,
68         errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
69                                 input.shape().DebugString()));
70 
71     TensorShape output_shape;
72     const TensorShape& input_shape = input.shape();
73     for (int d = 0; d < input_dims - 1; ++d) {
74       output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
75     }
76     Tensor* output = nullptr;
77     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
78 
79     if (output_shape.num_elements() == 0) {
80       return;
81     }
82 
83 #define HANDLE_DIM(NDIM)                                        \
84   case NDIM:                                                    \
85     ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(),   \
86                              input.tensor<T, NDIM>(), axis,     \
87                              output->tensor<Tout, NDIM - 1>()); \
88     break;
89 
90     switch (input_dims) {
91       HANDLE_DIM(1);
92       HANDLE_DIM(2);
93       HANDLE_DIM(3);
94       HANDLE_DIM(4);
95       HANDLE_DIM(5);
96 
97       default:
98         OP_REQUIRES(context, false,
99                     errors::InvalidArgument(
100                         "ArgOp : Unhandled input dimensions: ", input_dims));
101     }
102   }
103 #undef HANDLE_DIM
104 
105  private:
106   TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
107 };
108 
109 template <typename Device, typename T, typename Tout>
110 class ArgMaxOp
111     : public ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> > {
112  public:
ArgMaxOp(OpKernelConstruction * context)113   explicit ArgMaxOp(OpKernelConstruction* context)
114       : ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> >(context) {}
115 };
116 
117 template <typename Device, typename T, typename Tout>
118 class ArgMinOp
119     : public ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> > {
120  public:
ArgMinOp(OpKernelConstruction * context)121   explicit ArgMinOp(OpKernelConstruction* context)
122       : ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> >(context) {}
123 };
124 
125 #define REGISTER_ARGMAX(type)                                       \
126   REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
127                               .Device(DEVICE_CPU)                   \
128                               .TypeConstraint<type>("T")            \
129                               .TypeConstraint<int64>("output_type") \
130                               .HostMemory("dimension"),             \
131                           ArgMaxOp<CPUDevice, type, int64>);        \
132   REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
133                               .Device(DEVICE_CPU)                   \
134                               .TypeConstraint<type>("T")            \
135                               .TypeConstraint<int64>("output_type") \
136                               .HostMemory("dimension"),             \
137                           ArgMinOp<CPUDevice, type, int64>);        \
138   REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
139                               .Device(DEVICE_CPU)                   \
140                               .TypeConstraint<type>("T")            \
141                               .TypeConstraint<int32>("output_type") \
142                               .HostMemory("dimension"),             \
143                           ArgMaxOp<CPUDevice, type, int32>);        \
144   REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
145                               .Device(DEVICE_CPU)                   \
146                               .TypeConstraint<type>("T")            \
147                               .TypeConstraint<int32>("output_type") \
148                               .HostMemory("dimension"),             \
149                           ArgMinOp<CPUDevice, type, int32>);
150 
151 TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
152 
#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU. Their
// definitions are compiled elsewhere (not visible in this file). The
// `extern template` declarations stop this translation unit from implicitly
// instantiating the GPU functor classes itself.
namespace functor {

#define DECLARE_GPU_SPEC(T, Tout, Dims)                                       \
  template <>                                                                 \
  void ArgMax<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output); \
  template <>                                                                 \
  void ArgMin<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output);

// Ranks 1..5 for both output types, matching the ranks ArgOp::Compute
// dispatches on.
#define DECLARE_GPU_SPECS(T)     \
  DECLARE_GPU_SPEC(T, int64, 1); \
  DECLARE_GPU_SPEC(T, int64, 2); \
  DECLARE_GPU_SPEC(T, int64, 3); \
  DECLARE_GPU_SPEC(T, int64, 4); \
  DECLARE_GPU_SPEC(T, int64, 5); \
  DECLARE_GPU_SPEC(T, int32, 1); \
  DECLARE_GPU_SPEC(T, int32, 2); \
  DECLARE_GPU_SPEC(T, int32, 3); \
  DECLARE_GPU_SPEC(T, int32, 4); \
  DECLARE_GPU_SPEC(T, int32, 5);

#define DECLARE_GPU_CLASS(T)                          \
  extern template struct ArgMax<GPUDevice, T, int64>; \
  extern template struct ArgMin<GPUDevice, T, int64>; \
  extern template struct ArgMax<GPUDevice, T, int32>; \
  extern template struct ArgMin<GPUDevice, T, int32>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);

// Undefine every helper, including DECLARE_GPU_SPEC, which previously leaked.
#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_CLASS

}  // namespace functor

// Registration of the GPU implementations. "Tidx" is pinned to int32 and
// "dimension" stays in host memory, since the kernel reads the axis on the
// host as an int32 scalar.
#define REGISTER_ARGMAX_GPU(type)                                   \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("output_type") \
                              .TypeConstraint<int32>("Tidx")        \
                              .HostMemory("dimension"),             \
                          ArgMaxOp<GPUDevice, type, int64>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("output_type") \
                              .TypeConstraint<int32>("Tidx")        \
                              .HostMemory("dimension"),             \
                          ArgMinOp<GPUDevice, type, int64>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("output_type") \
                              .TypeConstraint<int32>("Tidx")        \
                              .HostMemory("dimension"),             \
                          ArgMaxOp<GPUDevice, type, int32>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("output_type") \
                              .TypeConstraint<int32>("Tidx")        \
                              .HostMemory("dimension"),             \
                          ArgMinOp<GPUDevice, type, int32>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);

#undef REGISTER_ARGMAX_GPU

#endif  // GOOGLE_CUDA
230 
231 }  // namespace tensorflow
232