• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/array_ops.cc
17 
18 #define EIGEN_USE_THREADS
19 
20 #if GOOGLE_CUDA
21 #define EIGEN_USE_GPU
22 #endif  // GOOGLE_CUDA
23 
24 #include "tensorflow/core/kernels/one_hot_op.h"
25 
26 #include <memory>
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_types.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/macros.h"
36 #include "tensorflow/core/util/overflow.h"
37 
38 namespace tensorflow {
39 
40 typedef Eigen::ThreadPoolDevice CPUDevice;
41 typedef Eigen::GpuDevice GPUDevice;
42 
43 template <typename Device, typename T, typename TI>
44 class OneHotOp : public OpKernel {
45  public:
OneHotOp(OpKernelConstruction * ctx)46   explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
47     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
48   }
49 
Compute(OpKernelContext * ctx)50   void Compute(OpKernelContext* ctx) override {
51     const Tensor& indices = ctx->input(0);
52     const Tensor& depth = ctx->input(1);
53     const Tensor& on_value = ctx->input(2);
54     const Tensor& off_value = ctx->input(3);
55     const TensorShape& indices_shape = indices.shape();
56 
57     const int indices_dims = indices_shape.dims();
58     const int output_dims = indices_dims + 1;
59 
60     // Preliminary validation of sizes.
61     OP_REQUIRES(
62         ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims),
63         errors::InvalidArgument("Expected axis to be -1 or between [0, ",
64                                 output_dims, ").  But received: ", axis_));
65     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth.shape()),
66                 errors::InvalidArgument("depth must be a scalar, but got: ",
67                                         depth.shape().DebugString()));
68     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value.shape()),
69                 errors::InvalidArgument("on_value must be a scalar, but got: ",
70                                         on_value.shape().DebugString()));
71     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value.shape()),
72                 errors::InvalidArgument("off_value must be a scalar, but got: ",
73                                         off_value.shape().DebugString()));
74 
75     const int axis = (axis_ == -1) ? indices_dims : axis_;
76 
77     // The one-hot dimension.
78     const int32 depth_v = depth.scalar<int32>()();
79     OP_REQUIRES(
80         ctx, depth_v >= 0,
81         errors::InvalidArgument("depth must be non-negative, got: ", depth_v));
82     OP_REQUIRES(
83         ctx,
84         MultiplyWithoutOverflow(indices_shape.num_elements(), depth_v) >= 0,
85         errors::InvalidArgument("OneHot result would have shape ",
86                                 indices_shape.DebugString(), " + [", depth_v,
87                                 "], which exceeds 2**63 - 1 elements"));
88 
89     TensorShape output_shape = indices_shape;
90     output_shape.InsertDim(axis, depth_v);
91 
92     auto on_value_t = on_value.scalar<T>();
93     auto off_value_t = off_value.scalar<T>();
94 
95     Tensor* output;
96     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));
97 
98     if (output_shape.num_elements() > 0) {
99       // prefix_dim_size == # of elements before the axis
100       // depth_v == # of elements per axis
101       // suffix_dim_size == # of elements after the axis
102       int64 prefix_dim_size = 1;
103       for (int i = 0; i < axis; ++i) {
104         prefix_dim_size *= indices_shape.dim_size(i);
105       }
106       TI suffix_dim_size = indices_shape.num_elements() / prefix_dim_size;
107 
108       // Split indices into matrix of size prefix_dim_size x suffix_dim_size
109       auto indices_t =
110           indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
111       // Split output into 3-Tensor of size:
112       //   prefix_dim_size x depth x suffix_dim_size.
113       auto output_t =
114           output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
115 
116       functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(),
117                                               indices_t, on_value_t,
118                                               off_value_t, &output_t);
119     }
120   }
121 
122  private:
123   int32 axis_;
124 
125   TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
126 };
127 
128 #define REGISTER_ONE_HOT_INDEX(type, index_type)                \
129   REGISTER_KERNEL_BUILDER(Name("OneHot")                        \
130                               .Device(DEVICE_CPU)               \
131                               .TypeConstraint<index_type>("TI") \
132                               .TypeConstraint<type>("T")        \
133                               .HostMemory("depth"),             \
134                           OneHotOp<CPUDevice, type, index_type>);
135 
136 #define REGISTER_ONE_HOT(type)         \
137   REGISTER_ONE_HOT_INDEX(type, uint8); \
138   REGISTER_ONE_HOT_INDEX(type, int32); \
139   REGISTER_ONE_HOT_INDEX(type, int64)
140 
141 TF_CALL_ALL_TYPES(REGISTER_ONE_HOT);
142 
143 #if GOOGLE_CUDA
144 
145 // Forward declarations of the functor specializations for GPU.
146 namespace functor {
147 #define DECLARE_GPU_SPEC_INDEX(T, TI)                                      \
148   template <>                                                              \
149   void OneHot<GPUDevice, T, TI>::Compute(                                  \
150       const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices, \
151       const typename TTypes<T>::ConstScalar& on_value,                     \
152       const typename TTypes<T>::ConstScalar& off_value,                    \
153       typename TTypes<T, 3>::Tensor* output);                              \
154   extern template struct OneHot<GPUDevice, T, TI>;
155 
156 #define DECLARE_GPU_SPEC(T)         \
157   DECLARE_GPU_SPEC_INDEX(T, uint8); \
158   DECLARE_GPU_SPEC_INDEX(T, int32); \
159   DECLARE_GPU_SPEC_INDEX(T, int64);
160 
161 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
162 TF_CALL_bool(DECLARE_GPU_SPEC);
163 TF_CALL_int32(DECLARE_GPU_SPEC);
164 TF_CALL_int64(DECLARE_GPU_SPEC);
165 
166 #undef DECLARE_GPU_SPEC_INDEX
167 #undef DECLARE_GPU_SPEC
168 
169 }  // namespace functor
170 
171 // Registration of the GPU implementations.
172 #define REGISTER_ONE_HOT_GPU_INDEX(type, index_type)            \
173   REGISTER_KERNEL_BUILDER(Name("OneHot")                        \
174                               .Device(DEVICE_GPU)               \
175                               .TypeConstraint<index_type>("TI") \
176                               .TypeConstraint<type>("T")        \
177                               .HostMemory("depth"),             \
178                           OneHotOp<GPUDevice, type, index_type>);
179 
180 #define REGISTER_ONE_HOT_GPU(type)         \
181   REGISTER_ONE_HOT_GPU_INDEX(type, uint8); \
182   REGISTER_ONE_HOT_GPU_INDEX(type, int32); \
183   REGISTER_ONE_HOT_GPU_INDEX(type, int64);
184 
185 TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU);
186 TF_CALL_bool(REGISTER_ONE_HOT_GPU);
187 TF_CALL_int32(REGISTER_ONE_HOT_GPU);
188 TF_CALL_int64(REGISTER_ONE_HOT_GPU);
189 
190 #undef REGISTER_ONE_HOT_GPU_INDEX
191 #undef REGISTER_ONE_HOT_GPU
192 
193 #endif  // GOOGLE_CUDA
194 
195 }  // namespace tensorflow
196