• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_H_
17 #define TENSORFLOW_CORE_KERNELS_CONV_2D_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 #include "tensorflow/core/framework/tensor_types.h"
21 #include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
22 #include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
23 #include "tensorflow/core/util/tensor_format.h"
24 
25 namespace tensorflow {
26 namespace functor {
27 
28 // TODO(yangke): revisit these operations and in particular, see if we can
29 // combine all of them into just one operation without causing nvcc to
30 // timeout.
31 template <typename Device, typename T, int Dims, typename IndexType>
32 struct ShuffleAndReverse {
operatorShuffleAndReverse33   void operator()(const Device& d,
34                   typename TTypes<T, Dims, IndexType>::ConstTensor input,
35                   const Eigen::DSizes<IndexType, Dims>& order,
36                   const Eigen::array<bool, Dims>& reverse_dims,
37                   typename TTypes<T, Dims, IndexType>::Tensor output) {
38     output.device(d) = input.shuffle(order).reverse(reverse_dims);
39   }
40 };
41 
42 template <typename Device, typename T, int Dims, typename IndexType>
43 struct InflatePadAndShuffle {
operatorInflatePadAndShuffle44   void operator()(
45       const Device& d, typename TTypes<T, Dims, IndexType>::ConstTensor input,
46       const Eigen::DSizes<IndexType, Dims>& strides,
47       const Eigen::array<Eigen::IndexPair<IndexType>, Dims>& pad_dims,
48       const Eigen::DSizes<IndexType, Dims>& order,
49       typename TTypes<T, Dims, IndexType>::Tensor output) {
50     output.device(d) = input.inflate(strides).pad(pad_dims).shuffle(order);
51   }
52 };
53 
54 template <typename Device, typename Input, typename Filter, typename Output,
55           typename OutputKernel>
SpatialConvolutionFunc(const Device & d,Output output,Input input,Filter filter,int row_stride,int col_stride,int row_dilation,int col_dilation,const Eigen::PaddingType & padding,const OutputKernel & output_kernel)56 void SpatialConvolutionFunc(const Device& d, Output output, Input input,
57                             Filter filter, int row_stride, int col_stride,
58                             int row_dilation, int col_dilation,
59                             const Eigen::PaddingType& padding,
60                             const OutputKernel& output_kernel) {
61   // Need to swap row/col when calling Eigen.
62   output.device(d) =
63       Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
64                                 col_dilation, row_dilation, output_kernel);
65 }
66 
67 template <typename Device, typename T,
68           typename OutputKernel = const Eigen::NoOpOutputKernel>
69 struct SpatialConvolution {
operatorSpatialConvolution70   void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
71                   typename TTypes<T, 4>::ConstTensor input,
72                   typename TTypes<T, 4>::ConstTensor filter, int row_stride,
73                   int col_stride, int row_dilation, int col_dilation,
74                   const Eigen::PaddingType& padding,
75                   const OutputKernel& output_kernel = OutputKernel()) {
76     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
77                            row_dilation, col_dilation, padding, output_kernel);
78   }
79 };
80 
81 template <typename Device, typename OutputKernel>
82 struct SpatialConvolution<Device, Eigen::half, OutputKernel> {
83   void operator()(const Device& d,
84                   typename TTypes<Eigen::half, 4>::Tensor output,
85                   typename TTypes<Eigen::half, 4>::ConstTensor input,
86                   typename TTypes<Eigen::half, 4>::ConstTensor filter,
87                   int row_stride, int col_stride, int row_dilation,
88                   int col_dilation, const Eigen::PaddingType& padding,
89                   const OutputKernel& output_kernel = OutputKernel()) {
90     output.device(d) =
91         Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
92                                   col_stride, row_stride, padding, col_dilation,
93                                   row_dilation, output_kernel)
94             .template cast<Eigen::half>();
95   }
96 };
97 
98 template <typename Device, typename T>
99 struct SpatialConvolutionBackwardInput {
100   void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
101                   typename TTypes<T, 4>::ConstTensor kernel,
102                   typename TTypes<T, 4>::ConstTensor output_backward,
103                   int row_stride, int col_stride, int row_dilation,
104                   int col_dilation) {
105     // Need to swap row/col when calling Eigen.
106     input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
107         kernel, output_backward, input_backward.dimension(2),
108         input_backward.dimension(1), col_stride, row_stride, col_dilation,
109         row_dilation);
110   }
111 };
112 
113 template <typename Device, typename T>
114 struct SpatialConvolutionBackwardFilter {
115   void operator()(const Device& d,
116                   typename TTypes<T, 4>::Tensor kernel_backward,
117                   typename TTypes<T, 4>::ConstTensor input,
118                   typename TTypes<T, 4>::ConstTensor output_backward,
119                   int row_stride, int col_stride, int row_dilation,
120                   int col_dilation) {
121     // Need to swap row/col when calling Eigen.
122     kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
123         input, output_backward, kernel_backward.dimension(1),
124         kernel_backward.dimension(0), col_stride, row_stride, col_dilation,
125         row_dilation);
126   }
127 };
128 
129 // TODO(vrv): Figure out how to use the MatMulFunctor in matmul_op.h.
130 // My initial attempt to do this compiled but failed in the pytest
131 // due to a swigdeps error.
132 template <typename Device, typename T,
133           typename OutputKernel = const Eigen::NoOpOutputKernel>
134 struct MatMulConvFunctor {
135   // Computes on device "d": out = in0 * in1, where * is matrix
136   // multiplication.
137   void operator()(
138       const Device& d, typename TTypes<T, 2>::Tensor out,
139       typename TTypes<T, 2>::ConstTensor in0,
140       typename TTypes<T, 2>::ConstTensor in1,
141       const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
142       const OutputKernel& output_kernel = OutputKernel()) {
143     out.device(d) = in0.contract(in1, dim_pair, output_kernel);
144   }
145 };
146 
147 // Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
148 //
149 // Note: Currently OIHW is the only supported destination format. Support for
150 // OHWI format will be added in a follow-up change.
151 template <typename Device, typename T, typename IndexType, int NDIMS>
152 struct TransformFilter {
153   void operator()(const Device& d, FilterTensorFormat dst_filter_format,
154                   typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
155                   typename TTypes<T, NDIMS, IndexType>::Tensor out) {
156     // Merge the spatial dimensions together to speed up the shuffle operation.
157     Eigen::DSizes<IndexType, 3> merged_dims;
158     merged_dims[0] = in.dimension(0);  // spatial dimensions
159     for (int i = 1; i < NDIMS - 2; ++i) {
160       merged_dims[0] *= in.dimension(i);
161     }
162     merged_dims[1] = in.dimension(NDIMS - 2);  // input filters
163     merged_dims[2] = in.dimension(NDIMS - 1);  // output filters
164 
165     DCHECK(dst_filter_format == FORMAT_OIHW)
166         << "Unsupported destination filter format: "
167         << ToString(dst_filter_format);
168     // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
169     // in the beginning.
170     Eigen::DSizes<IndexType, 3> shuffling_perm =
171         Eigen::DSizes<IndexType, 3>(2, 1, 0);
172 
173     Eigen::DSizes<IndexType, NDIMS> expanded_dims;
174     int out_index = 0;
175     for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
176       if (shuffling_perm[merged_dim] == 0) {
177         for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
178           expanded_dims[out_index++] = in.dimension(spatial_dim);
179         }
180       } else {
181         constexpr int kLastSpatialDim = NDIMS - 3;
182         expanded_dims[out_index++] =
183             in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
184       }
185     }
186 
187     out.device(d) =
188         in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims);
189   }
190 };
191 
192 template <typename Device, typename T, typename IndexType>
193 struct TransformDepth {
194   void operator()(const Device& d,
195                   typename TTypes<T, 4, IndexType>::ConstTensor in,
196                   const Eigen::DSizes<IndexType, 4>& shuffle,
197                   typename TTypes<T, 4, IndexType>::Tensor out) {
198     Eigen::DSizes<IndexType, 3> merged_dims;
199     Eigen::DSizes<IndexType, 4> expanded_dims;
200     Eigen::DSizes<IndexType, 3> new_shuffle;
201 
202     // Merge dimensions that won't be shuffled together to speed things up.
203     if (shuffle[1] == 2 && shuffle[2] == 3) {
204       merged_dims[0] = in.dimension(0);
205       merged_dims[1] = in.dimension(1);
206       merged_dims[2] = in.dimension(2) * in.dimension(3);
207       new_shuffle[0] = shuffle[0];
208       new_shuffle[1] = 2;
209       new_shuffle[2] = shuffle[3];
210       expanded_dims[0] = in.dimension(shuffle[0]);
211       expanded_dims[1] = in.dimension(2);
212       expanded_dims[2] = in.dimension(3);
213       expanded_dims[3] = in.dimension(shuffle[3]);
214     } else if (shuffle[0] == 2 && shuffle[1] == 3) {
215       merged_dims[0] = in.dimension(0);
216       merged_dims[1] = in.dimension(1);
217       merged_dims[2] = in.dimension(2) * in.dimension(3);
218       new_shuffle[0] = 2;
219       new_shuffle[1] = shuffle[2];
220       new_shuffle[2] = shuffle[3];
221       expanded_dims[0] = in.dimension(2);
222       expanded_dims[1] = in.dimension(3);
223       expanded_dims[2] = in.dimension(shuffle[2]);
224       expanded_dims[3] = in.dimension(shuffle[3]);
225     } else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 &&
226                shuffle[3] == 2) {
227       merged_dims[0] = in.dimension(0);
228       merged_dims[1] = in.dimension(1) * in.dimension(2);
229       merged_dims[2] = in.dimension(3);
230       new_shuffle[0] = 0;
231       new_shuffle[1] = 2;
232       new_shuffle[2] = 1;
233       expanded_dims[0] = in.dimension(0);
234       expanded_dims[1] = in.dimension(3);
235       expanded_dims[2] = in.dimension(1);
236       expanded_dims[3] = in.dimension(2);
237     } else {
238       assert(false && "unexpected shuffle");
239     }
240 
241     out.device(d) =
242         in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims);
243   }
244 };
245 
246 template <typename Device, typename T, typename IndexType, int NDIMS>
247 struct PadInput {
248   void operator()(const Device& d,
249                   typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
250                   const std::array<int, NDIMS - 2>& padding_left,
251                   const std::array<int, NDIMS - 2>& padding_right,
252                   typename TTypes<T, NDIMS, IndexType>::Tensor out,
253                   TensorFormat format) {
254     Eigen::array<Eigen::IndexPair<IndexType>, NDIMS> padding;
255     padding[GetTensorDimIndex<NDIMS - 2>(format, 'N')] = {0, 0};
256     for (int i = 0; i < NDIMS - 2; ++i) {
257       padding[GetTensorDimIndex<NDIMS - 2>(format, '0' + i)] = {
258           padding_left[i], padding_right[i]};
259     }
260     padding[GetTensorDimIndex<NDIMS - 2>(format, 'C')] = {0, 0};
261     out.device(d) = in.pad(padding);
262   }
263 };
264 
265 // Converts a tensor from:
266 //   [batch, <spatial>, filters]
267 // to:
268 //   [batch, filters, <spatial>]
269 template <typename Device, typename T, int NDIMS>
270 struct NHWCToNCHW {
271   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
272                   typename TTypes<T, NDIMS>::Tensor out);
273 };
274 
275 // Converts a tensor from:
276 //   [batch, filters, <spatial>]
277 // to:
278 //   [batch, <spatial>, filters]
279 template <typename Device, typename T, int NDIMS>
280 struct NCHWToNHWC {
281   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
282                   typename TTypes<T, NDIMS>::Tensor out);
283 };
284 
285 // Converts a tensor from:
286 //   [dim0, dim1, dim2]
287 // to:
288 //   [dim0, dim2, dim1]
289 template <typename Device, typename T, bool conjugate = false>
290 struct SwapDimension1And2InTensor3 {
291   void operator()(const Device& d, const T* in,
292                   const gtl::ArraySlice<int64>& input_dims, T* out);
293 };
294 
295 // Converts a tensor from:
296 //   [dim0, dim1, dim2]
297 // to:
298 //   [dim2, dim1, dim0]
299 template <typename Device, typename T, bool conjugate = false>
300 struct SwapDimension0And2InTensor3 {
301   void operator()(const Device& d, const T* in,
302                   const gtl::ArraySlice<int64>& input_dims, T* out);
303 };
304 
305 // Transforms back filter from OIHW to HWOI format to reverse effect of
306 // TransformFilter above.
307 // TODO(hinsu): Support reverse transformation from filter format OHWI as well.
308 template <typename Device, typename T, int NDIMS>
309 struct ReverseTransformFilter {
310   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
311                   typename TTypes<T, NDIMS>::Tensor out);
312 };
313 
314 }  // namespace functor
315 
316 template <class T>
317 class ConvAlgorithmMap;
318 
319 template <>
320 class ConvAlgorithmMap<Eigen::ThreadPoolDevice> {};
321 }  // namespace tensorflow
322 
323 #endif  // TENSORFLOW_CORE_KERNELS_CONV_2D_H_
324