1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include "tensorflow/core/kernels/data_format_ops.h" 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/tensor.h" 25 26 namespace tensorflow { 27 28 typedef Eigen::ThreadPoolDevice CPUDevice; 29 typedef Eigen::GpuDevice GPUDevice; 30 31 template <typename Device, typename T> 32 class DataFormatDimMapOp : public OpKernel { 33 public: DataFormatDimMapOp(OpKernelConstruction * context)34 explicit DataFormatDimMapOp(OpKernelConstruction* context) 35 : OpKernel(context) { 36 string src_format; 37 OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); 38 string dst_format; 39 OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); 40 OP_REQUIRES(context, src_format.size() == 4, 41 errors::InvalidArgument(strings::StrCat( 42 "Source format must of length 4, received src_format = ", 43 src_format))); 44 OP_REQUIRES( 45 context, dst_format.size() == 4, 46 errors::InvalidArgument(strings::StrCat( 47 "Destination format must of length 4, received dst_format = ", 48 dst_format))); 49 dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())}); 50 for (int i = 0; i < src_format.size(); ++i) { 51 for (int j = 0; j < dst_format.size(); ++j) { 52 if (dst_format[j] == src_format[i]) { 53 dst_idx_.vec<int>()(i) = j; 54 break; 55 } 56 } 57 } 58 } 59 Compute(OpKernelContext * context)60 void Compute(OpKernelContext* context) override { 61 const Tensor& input = context->input(0); 62 Tensor* output; 63 OP_REQUIRES_OK(context, 64 context->allocate_output(0, input.shape(), &output)); 65 functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(), 66 input.flat<T>(), output->flat<T>(), 67 dst_idx_.vec<int>()); 68 } 69 70 Tensor dst_idx_; 71 }; 72 73 template <typename Device, typename T> 74 class DataFormatVecPermuteOp : public OpKernel { 75 public: DataFormatVecPermuteOp(OpKernelConstruction * context)76 explicit DataFormatVecPermuteOp(OpKernelConstruction* context) 77 : OpKernel(context) { 78 string src_format; 79 OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); 80 string dst_format; 81 OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); 82 src_format_ = src_format; 83 dst_format_ = dst_format; 84 } 85 Compute(OpKernelContext * context)86 void Compute(OpKernelContext* context) override { 87 const Tensor& input = context->input(0); 88 OP_REQUIRES(context, input.dims() == 1 || input.dims() == 2, 89 errors::InvalidArgument( 90 "input must be a vector or 2D tensor, but got shape ", 91 input.shape().DebugString())); 92 if (input.dims() == 1) { 93 OP_REQUIRES( 94 context, input.NumElements() == 4, 95 errors::InvalidArgument("1D input must be of size 4, but got shape ", 96 input.shape().DebugString())); 97 } else if (input.dims() == 2) { 98 OP_REQUIRES( 99 context, input.dim_size(0) == 4, 100 errors::InvalidArgument( 101 "First dimension of 2D input must be of size 4, but got shape ", 102 input.shape().DebugString())); 103 OP_REQUIRES( 104 context, input.dim_size(1) == 2, 105 errors::InvalidArgument( 106 "Second dimension of 2D input must be of size 2, but got shape ", 107 input.shape().DebugString())); 108 } 109 110 Tensor* output = nullptr; 111 OP_REQUIRES_OK(context, 112 context->allocate_output(0, input.shape(), &output)); 113 // Support 1D and 2D cases. 114 Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx; 115 ComputeDstIndex(input.dims(), &dst_idx); 116 117 functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(), 118 input.flat<T>(), 119 output->flat<T>(), dst_idx); 120 } 121 122 private: 123 // Finds out the destination index. Support 1D and 2D cases. 124 // Example: HWNC --> NHWC 125 // 1D: dst = [1, 2, 0, 3], 126 // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7] ComputeDstIndex(int num_dim,Eigen::DSizes<Eigen::DenseIndex,8> * dst)127 void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) { 128 for (int i = 0; i < src_format_.size(); ++i) { 129 for (int j = 0; j < dst_format_.size(); ++j) { 130 if (dst_format_[j] != src_format_[i]) continue; 131 // Found the dst index. Set output based on the number of dims. 132 for (int k = 0; k < num_dim; ++k) { 133 (*dst)[i * num_dim + k] = j * num_dim + k; 134 } 135 } 136 } 137 } 138 139 string src_format_; 140 string dst_format_; 141 }; 142 143 #define REGISTER_KERNEL(T) \ 144 REGISTER_KERNEL_BUILDER( \ 145 Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 146 DataFormatDimMapOp<CPUDevice, T>); 147 TF_CALL_int32(REGISTER_KERNEL); 148 TF_CALL_int64(REGISTER_KERNEL); 149 #undef REGISTER_KERNEL 150 151 #define REGISTER_KERNEL(T) \ 152 REGISTER_KERNEL_BUILDER( \ 153 Name("DataFormatVecPermute").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 154 DataFormatVecPermuteOp<CPUDevice, T>); 155 TF_CALL_int32(REGISTER_KERNEL); 156 TF_CALL_int64(REGISTER_KERNEL); 157 #undef REGISTER_KERNEL 158 159 #define REGISTER_KERNEL(T) \ 160 REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute") \ 161 .Device(DEVICE_CPU) \ 162 .Label("host") \ 163 .TypeConstraint<T>("T"), \ 164 DataFormatVecPermuteOp<CPUDevice, T>); 165 TF_CALL_int32(REGISTER_KERNEL); 166 TF_CALL_int64(REGISTER_KERNEL); 167 #undef REGISTER_KERNEL 168 169 #if GOOGLE_CUDA 170 // Forward declarations of the functor specializations for GPU. 171 namespace functor { 172 #define DECLARE_GPU_SPEC(T) \ 173 template <> \ 174 void DataFormatDimMap<GPUDevice, T>::operator()( \ 175 const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ 176 typename TTypes<T>::Flat y, const TTypes<int>::Vec dst); \ 177 extern template struct DataFormatDimMap<GPUDevice, T>; 178 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); 179 TF_CALL_int32(DECLARE_GPU_SPECS); 180 TF_CALL_int64(DECLARE_GPU_SPECS); 181 #undef DECLARE_GPU_SPEC 182 183 #define DECLARE_GPU_SPEC(T) \ 184 template <> \ 185 void DataFormatVecPermute<GPUDevice, T>::operator()( \ 186 const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ 187 typename TTypes<T>::Vec y, \ 188 const Eigen::DSizes<Eigen::DenseIndex, 8>& dst_idx); \ 189 extern template struct DataFormatVecPermute<GPUDevice, T>; 190 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); 191 TF_CALL_int32(DECLARE_GPU_SPECS); 192 TF_CALL_int64(DECLARE_GPU_SPECS); 193 #undef DECLARE_GPU_SPEC 194 } // namespace functor 195 196 // Registration of the GPU implementations. 197 #define REGISTER_GPU_KERNEL(T) \ 198 REGISTER_KERNEL_BUILDER( \ 199 Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 200 DataFormatDimMapOp<GPUDevice, T>); 201 TF_CALL_int32(REGISTER_GPU_KERNEL); 202 TF_CALL_int64(REGISTER_GPU_KERNEL); 203 #undef REGISTER_GPU_KERNEL 204 205 #define REGISTER_GPU_KERNEL(T) \ 206 REGISTER_KERNEL_BUILDER( \ 207 Name("DataFormatVecPermute").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 208 DataFormatVecPermuteOp<GPUDevice, T>); \ 209 REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute") \ 210 .Device(DEVICE_GPU) \ 211 .HostMemory("x") \ 212 .HostMemory("y") \ 213 .Label("host") \ 214 .TypeConstraint<T>("T"), \ 215 DataFormatVecPermuteOp<CPUDevice, T>); 216 TF_CALL_int32(REGISTER_GPU_KERNEL); 217 TF_CALL_int64(REGISTER_GPU_KERNEL); 218 #undef REGISTER_GPU_KERNEL 219 #endif // GOOGLE_CUDA 220 221 } // namespace tensorflow 222