• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "tensorflow/core/kernels/data_format_ops.h"
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/register_types.h"
24 #include "tensorflow/core/framework/tensor.h"
25 
26 namespace tensorflow {
27 
28 typedef Eigen::ThreadPoolDevice CPUDevice;
29 typedef Eigen::GpuDevice GPUDevice;
30 
31 template <typename Device, typename T>
32 class DataFormatDimMapOp : public OpKernel {
33  public:
DataFormatDimMapOp(OpKernelConstruction * context)34   explicit DataFormatDimMapOp(OpKernelConstruction* context)
35       : OpKernel(context) {
36     string src_format;
37     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
38     string dst_format;
39     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
40     OP_REQUIRES(context, src_format.size() == 4,
41                 errors::InvalidArgument(strings::StrCat(
42                     "Source format must of length 4, received src_format = ",
43                     src_format)));
44     OP_REQUIRES(
45         context, dst_format.size() == 4,
46         errors::InvalidArgument(strings::StrCat(
47             "Destination format must of length 4, received dst_format = ",
48             dst_format)));
49     dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
50     for (int i = 0; i < src_format.size(); ++i) {
51       for (int j = 0; j < dst_format.size(); ++j) {
52         if (dst_format[j] == src_format[i]) {
53           dst_idx_.vec<int>()(i) = j;
54           break;
55         }
56       }
57     }
58   }
59 
Compute(OpKernelContext * context)60   void Compute(OpKernelContext* context) override {
61     const Tensor& input = context->input(0);
62     Tensor* output;
63     OP_REQUIRES_OK(context,
64                    context->allocate_output(0, input.shape(), &output));
65     functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
66                                            input.flat<T>(), output->flat<T>(),
67                                            dst_idx_.vec<int>());
68   }
69 
70   Tensor dst_idx_;
71 };
72 
73 template <typename Device, typename T>
74 class DataFormatVecPermuteOp : public OpKernel {
75  public:
DataFormatVecPermuteOp(OpKernelConstruction * context)76   explicit DataFormatVecPermuteOp(OpKernelConstruction* context)
77       : OpKernel(context) {
78     string src_format;
79     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
80     string dst_format;
81     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
82     src_format_ = src_format;
83     dst_format_ = dst_format;
84   }
85 
Compute(OpKernelContext * context)86   void Compute(OpKernelContext* context) override {
87     const Tensor& input = context->input(0);
88     OP_REQUIRES(context, input.dims() == 1 || input.dims() == 2,
89                 errors::InvalidArgument(
90                     "input must be a vector or 2D tensor, but got shape ",
91                     input.shape().DebugString()));
92     if (input.dims() == 1) {
93       OP_REQUIRES(
94           context, input.NumElements() == 4,
95           errors::InvalidArgument("1D input must be of size 4, but got shape ",
96                                   input.shape().DebugString()));
97     } else if (input.dims() == 2) {
98       OP_REQUIRES(
99           context, input.dim_size(0) == 4,
100           errors::InvalidArgument(
101               "First dimension of 2D input must be of size 4, but got shape ",
102               input.shape().DebugString()));
103       OP_REQUIRES(
104           context, input.dim_size(1) == 2,
105           errors::InvalidArgument(
106               "Second dimension of 2D input must be of size 2, but got shape ",
107               input.shape().DebugString()));
108     }
109 
110     Tensor* output = nullptr;
111     OP_REQUIRES_OK(context,
112                    context->allocate_output(0, input.shape(), &output));
113     // Support 1D and 2D cases.
114     Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx;
115     ComputeDstIndex(input.dims(), &dst_idx);
116 
117     functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
118                                                input.flat<T>(),
119                                                output->flat<T>(), dst_idx);
120   }
121 
122  private:
123   // Finds out the destination index. Support 1D and 2D cases.
124   // Example: HWNC --> NHWC
125   // 1D: dst = [1, 2, 0, 3],
126   // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
ComputeDstIndex(int num_dim,Eigen::DSizes<Eigen::DenseIndex,8> * dst)127   void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
128     for (int i = 0; i < src_format_.size(); ++i) {
129       for (int j = 0; j < dst_format_.size(); ++j) {
130         if (dst_format_[j] != src_format_[i]) continue;
131         // Found the dst index. Set output based on the number of dims.
132         for (int k = 0; k < num_dim; ++k) {
133           (*dst)[i * num_dim + k] = j * num_dim + k;
134         }
135       }
136     }
137   }
138 
139   string src_format_;
140   string dst_format_;
141 };
142 
// CPU kernel registrations. Both ops operate on dimension indices / shape
// vectors, so they are registered only for the integer types int32 and int64.
#define REGISTER_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#define REGISTER_KERNEL(T)                                                    \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("DataFormatVecPermute").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

// Additional CPU registration of DataFormatVecPermute under the "host" label,
// so callers can explicitly select the host (CPU) kernel by label.
#define REGISTER_KERNEL(T)                             \
  REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute") \
                              .Device(DEVICE_CPU)      \
                              .Label("host")           \
                              .TypeConstraint<T>("T"), \
                          DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL
168 
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
// The definitions live in the .cu.cc file; `extern template` prevents this
// translation unit from instantiating them.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                    \
  template <>                                                  \
  void DataFormatDimMap<GPUDevice, T>::operator()(             \
      const GPUDevice& d, typename TTypes<T>::ConstFlat x,     \
      typename TTypes<T>::Flat y, const TTypes<int>::Vec dst); \
  extern template struct DataFormatDimMap<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC

#define DECLARE_GPU_SPEC(T)                                \
  template <>                                              \
  void DataFormatVecPermute<GPUDevice, T>::operator()(     \
      const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
      typename TTypes<T>::Vec y,                           \
      const Eigen::DSizes<Eigen::DenseIndex, 8>& dst_idx); \
  extern template struct DataFormatVecPermute<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DataFormatDimMapOp<GPUDevice, T>);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

// The "host"-labeled variant below runs the CPU implementation even when
// placed on a GPU device: x and y are pinned to host memory via HostMemory.
#define REGISTER_GPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("DataFormatVecPermute").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DataFormatVecPermuteOp<GPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute")                        \
                              .Device(DEVICE_GPU)                             \
                              .HostMemory("x")                                \
                              .HostMemory("y")                                \
                              .Label("host")                                  \
                              .TypeConstraint<T>("T"),                        \
                          DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA
220 
221 }  // namespace tensorflow
222