/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains a set of different implementations of the two-dimensional
// convolution operation. The standard TensorFlow Conv2d kernel uses EigenTensor
// to implement the computation, but this module has a variety of different ways
// of producing the same result. These methods are designed to be easier to
// understand and connect to other libraries, so that we can take advantage of
// platforms that have specialized implementations of GEMM for example.
//
// The basic interface is a Conv functor object that's templated by the types
// of the data it will be operating on, and is passed in the arguments needed to
// calculate the convolution. The simplest implementation of this functor is
// ReferenceConvFunctor, which is a readable but slow reference version.
//
// A faster version uses the approach of packing image patches into a matrix
// before calling a matrix multiply, the Im2ColConvFunctor. In turn, this can
// use a variety of different methods to calculate the matrix multiplication,
// or GEMM. The simplest but slowest is the ReferenceGemmFunctor, but the
// FastGemmFunctor will use whatever optimized libraries are available. By
// default it uses Eigen, but on Apple platforms it will take advantage of the
// system's Accelerate BLAS library to get better performance than the standard
// TensorFlow convolution kernel.
//
// The version actually used is defined at the bottom of this file using the
// REGISTER_KERNEL_BUILDER() macro. To try out different implementations (for
// example to switch to a reference one for easier debugging) you can swap out
// the default functors in that call.
//
// The registration itself is guarded with the USE_GEMM_FOR_CONV macro. The iOS
// makefile build defines this, but if you want to enable this implementation
// and disable the standard EigenTensor one in other build setups, you'll need
// to define it there too.
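//
// As an illustrative sketch only (not compiled by default), registering the
// reference implementation for float instead of the GEMM-backed one would look
// something like:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       Conv2DUsingGemmOp<float, ReferenceConvFunctor<float, float, float>>);
//
// in place of the Im2ColConvFunctor-based registration at the bottom of this
// file.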

#define EIGEN_USE_THREADS

#include <string.h>

#include <map>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/util/image_resizer_state.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace {
// This function implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal benchmarks
// to prototype with, and sharing with teams that want to run this outside of
// our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, const T2* filter_data, int filter_height,
                  int filter_width, int filter_count, int stride_rows,
                  int stride_cols, Padding padding, T3* output_data,
                  int output_height, int output_width) {
    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as the
    // input. It's complicated by stride, which shrinks the output image by a
    // factor, but it means we end up sampling from outside the borders of the
    // input. These out-of-bounds values are read as zeroes. VALID means only
    // produce output values where the filters can read all their values from
    // within the input image. It effectively removes the margins of the output
    // image compared to the one produced by SAME. Stride complicates this
    // definition though, because it can result in the right and bottom filter
    // patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           input_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - input_height) /
          2;
    }
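    // As a concrete illustration (numbers chosen only for this comment): with
    // input_width = 5, filter_width = 3, stride_cols = 1 and SAME padding,
    // output_width is 5 and filter_left_offset = ((5 - 1) * 1 + 3 - 5) / 2 = 1,
    // so the leftmost filter column starts one pixel off the left edge of the
    // input and reads a zero there.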

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count; ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
             *-------------------------------...
             |\ ^
             | \in_depth
             |  \ v
             |   *-------------------------------...
             |   |            ^
             |   |       in_y_origin
             |   |            v   \
             |   |<in_x_origin>*---*^
             |   |            \|   |filter_height
             .   |             *---*v
             .   |             <--->
             .         filter_width
             .
            */
            const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
            const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
            T3 total(0);
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  T1 input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    input_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                  } else {
                    input_value = T1(0);
                  }
                  const T2 filter_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  total += (input_value * filter_value);
                }
              }
            }
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = total;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android and
// other platforms using Eigen, and 1MB for Apple devices, from experimentation.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif

// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
template <class T1, class T2, class T3, class TGemmFunctor>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, const T2* filter_data, int filter_height,
                  int filter_width, int filter_count, int stride_rows,
                  int stride_cols, Padding padding, T3* output_data,
                  int output_height, int output_width) {
    if ((input_batches <= 0) || (input_width <= 0) || (input_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << input_height << ", "
                   << input_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }

    // We can just use a GEMM if the im2col is the identity operator, e.g., if
    // the kernel is 1x1 or the input data and filter have the same height/width.
    if (filter_height == 1 && filter_width == 1 && stride_rows == 1 &&
        stride_cols == 1) {
      // The kernel is 1x1.
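      // In this case im2col would be the identity, so (as a way to read the
      // m/n/k values below) the convolution is a single matrix multiply of the
      // [input_batches * input_height * input_width, input_depth] input by the
      // [input_depth, filter_count] filter.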
      const int m = input_batches * input_height * input_width;
      const int n = filter_count;
      const int k = input_depth;
      const int lda = k;
      const int ldb = filter_count;
      const int ldc = filter_count;
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb,
                   output_data, ldc);
      return;
    } else if (filter_height == input_height && filter_width == input_width &&
               padding == VALID) {
      // The input data and filter have the same height/width.
      const int m = input_batches;
      const int n = filter_count;
      const int k = input_height * input_width * input_depth;
      const int lda = k;
      const int ldb = filter_count;
      const int ldc = filter_count;
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb,
                   output_data, ldc);
      return;
    }

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           input_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - input_height) /
          2;
    }

    // The im2col buffer has (number of patches) rows and (filter value count)
    // columns. It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;
    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const int64_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64_t chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
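    // To make the chunking concrete (illustrative numbers only): with float
    // data, a 3x3 filter and input_depth = 8, filter_value_count is 72, so on
    // the 16MB setting patches_per_chunk = 16 * 1024 * 1024 / (72 * 4) = 58254
    // patches are packed and multiplied per chunk.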
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64_t patch_count = (input_batches * output_height * output_width);
    const int64_t chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
    for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64_t patch_index_start = chunk_index * patches_per_chunk;
      const int64_t patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64_t patch_index = patch_index_start;
           patch_index < patch_index_end; ++patch_index) {
        const int64_t batch = patch_index / (output_height * output_width);
        const int64_t out_y = (patch_index / output_width) % output_height;
        const int64_t out_x = patch_index % output_width;
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the
          // whole row with zeroes.
          if ((in_y < 0) || (in_y >= input_height)) {
            T1* im2col_row_end =
                im2col_row_start + (filter_width * input_depth);
            std::fill(im2col_row_start, im2col_row_end, T1(0));
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                 input_width       in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use std::fill() to set the left and right sections to zeroes
            // and std::copy() to copy over the input data for the center.
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
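            // For instance (illustrative values): if in_x_origin is -1 with
            // filter_width = 3 and input_width = 5, then in_x_end is 2,
            // left_zero_count is 1, right_zero_count is 0 and
            // center_copy_count is 2, so one column of zeroes is written
            // followed by two columns copied from the input row.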
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              T1* im2col_left_end =
                  im2col_left_start + (left_zero_count * input_depth);
              std::fill(im2col_left_start, im2col_left_end, T1(0));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              const T1* input_row_end =
                  input_row_start + (center_copy_count * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              std::copy(input_row_start, input_row_end, im2col_center_start);
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              T1* im2col_right_end =
                  im2col_right_start + (right_zero_count * input_depth);
              std::fill(im2col_right_start, im2col_right_end, T1(0));
            }
          }
        }
      }
      // Now that we've assembled a set of image patches into a matrix, apply
      // a GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
      const int how_many_patches = patch_index_end - patch_index_start;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
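      // The product is a [how_many_patches, filter_value_count] patch matrix
      // times the [filter_value_count, filter_count] filter matrix, written
      // straight into this chunk's rows of the output below.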
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                   chunk_output_data, ldc);
    }
  }
};

}  // namespace

// This TensorFlow kernel class handles all of the IO and housekeeping for the
// functors that actually implement the underlying algorithm. To swap in
// different implementations of the main calculations, use a different
// TConvFunctor parameter when instantiating the template.
template <class T, class TConvFunctor>
class Conv2DUsingGemmOp : public BinaryOp<T> {
 public:
  explicit Conv2DUsingGemmOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Data format not supported by this kernel", data_format));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64_t stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64_t stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64_t in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64_t input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int input_rows = static_cast<int>(input_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64_t input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int input_cols = static_cast<int>(input_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));

    // The first dimension for input is batch.
    const int64_t batch_raw = GetTensorDim(input, data_format_, 'N');
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
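    // GetWindowedOutputSize() applies the usual TensorFlow SAME/VALID rules:
    // for SAME padding the output extent is ceil(input / stride), and for
    // VALID it is ceil((input - filter + 1) / stride). For example
    // (illustrative numbers), input_rows = 10, filter_rows = 3 and
    // stride_rows = 2 gives out_rows = 5 with SAME and 4 with VALID.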
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "Conv2D: in_depth = " << in_depth
            << ", input_cols = " << input_cols
            << ", filter_cols = " << filter_cols
            << ", input_rows = " << input_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input.flat<T>().data(), batch, input_rows, input_cols,
                 in_depth, filter.flat<T>().data(), filter_rows, filter_cols,
                 out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DUsingGemmOp);
};

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DUsingGemmOp<                                        \
          T, Im2ColConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);

// Only register this GEMM-based implementation of Conv2d if the compiler flags
// request the implementation explicitly, since otherwise it will clash with the
// default EigenTensor-based kernel.
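// In build setups other than the iOS makefile, that generally means adding a
// definition such as -DUSE_GEMM_FOR_CONV to the compile flags for this file.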
#if defined(USE_GEMM_FOR_CONV)
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
#endif  // USE_GEMM_FOR_CONV

}  // namespace tensorflow
580