• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 
18 #define USE_EIGEN_TENSOR
19 #define EIGEN_USE_THREADS
20 
21 #if GOOGLE_CUDA
22 #define EIGEN_USE_GPU
23 #endif  // GOOGLE_CUDA
24 
25 #include "tensorflow/core/kernels/conv_ops.h"
26 
27 #include <string.h>
28 #include <map>
29 #include <vector>
30 
31 #include "tensorflow/core/framework/bounds_check.h"
32 #include "tensorflow/core/framework/numeric_op.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/register_types.h"
35 #include "tensorflow/core/framework/tensor.h"
36 #include "tensorflow/core/framework/tensor_shape.h"
37 #include "tensorflow/core/framework/tensor_slice.h"
38 #include "tensorflow/core/kernels/conv_2d.h"
39 #include "tensorflow/core/kernels/deep_conv2d.h"
40 #include "tensorflow/core/kernels/ops_util.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/array_slice.h"
43 #include "tensorflow/core/lib/strings/numbers.h"
44 #include "tensorflow/core/lib/strings/str_util.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/macros.h"
47 #include "tensorflow/core/util/padding.h"
48 #include "tensorflow/core/util/tensor_format.h"
49 #include "tensorflow/core/util/use_cudnn.h"
50 
51 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
52 #include "tensorflow/core/kernels/xsmm_conv2d.h"
53 #endif
54 
55 #if GOOGLE_CUDA
56 #include "tensorflow/core/kernels/conv_ops_gpu.h"
57 #include "tensorflow/core/platform/stream_executor.h"
58 #include "tensorflow/core/protobuf/autotuning.pb.h"
59 #include "tensorflow/core/util/proto/proto_utils.h"
60 #endif  // GOOGLE_CUDA
61 
62 namespace tensorflow {
63 
64 typedef Eigen::ThreadPoolDevice CPUDevice;
65 typedef Eigen::GpuDevice GPUDevice;
66 
67 namespace {
68 template <typename Device, typename T>
69 struct LaunchGeneric {
operator ()tensorflow::__anonfb357e540111::LaunchGeneric70   void operator()(OpKernelContext* ctx, const Tensor& input,
71                   const Tensor& filter, int row_stride, int col_stride,
72                   int row_dilation, int col_dilation, const Padding& padding,
73                   Tensor* output, TensorFormat data_format) {
74     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
75                                          "supports NHWC tensor format for now.";
76     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
77         col_stride == 1) {
78       // For 1x1 kernel, the 2D convolution is reduced to matrix
79       // multiplication.
80       //
81       // TODO(vrv): We should be able to call SpatialConvolution
82       // and it will produce the same result, but doing so
83       // led to NaNs during training.  Using matmul instead for now.
84       int conv_width = 1;  // Width for the convolution step.
85       for (int i = 0; i < 3; ++i) {
86         conv_width *= output->dim_size(i);
87       }
88 
89       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
90       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
91       functor::MatMulConvFunctor<Device, T>()(
92           ctx->eigen_device<Device>(),
93           output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
94           input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
95           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
96           dim_pair);
97     } else if (filter.dim_size(0) == input.dim_size(1) &&
98                filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
99                col_dilation == 1 && padding == VALID) {
100       // If the input data and filter have the same height/width,
101       // the 2D convolution is reduced to matrix multiplication.
102       const int k =  // Length of reduction dimension.
103           filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
104 
105       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
106       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
107       functor::MatMulConvFunctor<Device, T>()(
108           ctx->eigen_device<Device>(),
109           output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
110           input.shaped<T, 2>({input.dim_size(0), k}),
111           filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair);
112     } else {
113       functor::SpatialConvolution<Device, T>()(
114           ctx->eigen_device<Device>(), output->tensor<T, 4>(),
115           input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
116           row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
117     }
118   }
119 };
120 }  // namespace
121 
122 template <typename T>
123 struct LaunchConv2DOp<CPUDevice, T> {
operator ()tensorflow::LaunchConv2DOp124   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
125                   const Tensor& input, const Tensor& filter, int row_dilation,
126                   int col_dilation, int row_stride, int col_stride,
127                   const Padding& padding,
128                   const std::vector<int64>& explicit_paddings, Tensor* output,
129                   TensorFormat data_format) {
130     if (data_format != FORMAT_NHWC) {
131       ctx->SetStatus(
132           errors::Unimplemented("Generic conv implementation only supports "
133                                 "NHWC tensor format for now."));
134       return;
135     }
136     // TODO(reedwm): Enable explicit padding on the CPU.
137     OP_REQUIRES(
138         ctx, padding != Padding::EXPLICIT,
139         errors::Unimplemented("Generic conv implementation does not support "
140                               "EXPLICIT padding yet."));
141     const int64 in_depth = GetTensorDim(input, data_format, 'C');
142     OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
143                 errors::Unimplemented("Generic conv implementation does not "
144                                       "support grouped convolutions for now."));
145     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
146                                   row_dilation, col_dilation, padding, output,
147                                   data_format);
148   }
149 };
150 
151 template <typename Device, typename T>
152 class LaunchDeepConvOp {
153  public:
Run(OpKernelContext * ctx,const Tensor & input,const Tensor & filter,int batch,int input_rows,int input_cols,int in_depth,int filter_rows,int filter_cols,int pad_rows,int pad_cols,int out_rows,int,int,int,int,int,int,Tensor *,TensorFormat)154   static bool Run(OpKernelContext* ctx, const Tensor& input,
155                   const Tensor& filter, int batch, int input_rows,
156                   int input_cols, int in_depth, int filter_rows,
157                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
158                   int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/,
159                   int /*dilation_cols*/, int /*stride_rows*/,
160                   int /*stride_cols*/, Tensor* /*output*/,
161                   TensorFormat /*data_format*/) {
162     return false;
163   }
164 };
165 
166 // Conditionally launches DeepConv operation based on convolution parameters.
167 template <>
168 class LaunchDeepConvOp<CPUDevice, float> {
169  public:
Run(OpKernelContext * ctx,const Tensor & input,const Tensor & filter,int batch,int input_rows,int input_cols,int in_depth,int filter_rows,int filter_cols,int pad_rows,int pad_cols,int out_rows,int out_cols,int out_depth,int dilation_rows,int dilation_cols,int stride_rows,int stride_cols,Tensor * output,TensorFormat data_format)170   static bool Run(OpKernelContext* ctx, const Tensor& input,
171                   const Tensor& filter, int batch, int input_rows,
172                   int input_cols, int in_depth, int filter_rows,
173                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
174                   int out_cols, int out_depth, int dilation_rows,
175                   int dilation_cols, int stride_rows, int stride_cols,
176                   Tensor* output, TensorFormat data_format) {
177     if (data_format != FORMAT_NHWC || dilation_rows != 1 ||
178         dilation_cols != 1 ||
179         !CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
180                           in_depth, out_depth, out_rows, out_cols)) {
181       return false;
182     }
183 
184     Conv2DArgs args;
185     args.batch = batch;
186     args.in_rows = input_rows;
187     args.in_cols = input_cols;
188     args.in_depth = in_depth;
189     args.filter_rows = filter_rows;
190     args.filter_cols = filter_cols;
191     args.pad_rows = pad_rows;
192     args.pad_cols = pad_cols;
193     args.out_rows = out_rows;
194     args.out_cols = out_cols;
195     args.out_depth = out_depth;
196 
197     auto input_ptr = input.template flat<float>().data();
198     auto filter_ptr = filter.template flat<float>().data();
199     auto output_ptr = output->template flat<float>().data();
200 
201     functor::DeepConv2D<CPUDevice, float>()(ctx, args, input_ptr, filter_ptr,
202                                             output_ptr);
203     return true;
204   }
205 };
206 
207 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
208 template <typename Device, typename T>
209 class LaunchXsmmConvOp {
210  public:
Run(OpKernelContext * ctx,const Tensor & input,const Tensor & filter,int batch,int input_rows,int input_cols,int in_depth,int filter_rows,int filter_cols,int pad_rows,int pad_cols,int out_rows,int out_cols,int out_depth,int stride_rows,int stride_cols,int dilation_rows,int dilation_cols,Tensor * output,TensorFormat data_format)211   static bool Run(OpKernelContext* ctx, const Tensor& input,
212                   const Tensor& filter, int batch, int input_rows,
213                   int input_cols, int in_depth, int filter_rows,
214                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
215                   int out_cols, int out_depth, int stride_rows, int stride_cols,
216                   int dilation_rows, int dilation_cols, Tensor* output,
217                   TensorFormat data_format) {
218     return false;
219   }
220 };
221 
222 template <>
223 class LaunchXsmmConvOp<CPUDevice, float> {
224  public:
Run(OpKernelContext * ctx,const Tensor & input,const Tensor & filter,int batch,int input_rows,int input_cols,int in_depth,int filter_rows,int filter_cols,int pad_rows,int pad_cols,int out_rows,int out_cols,int out_depth,int dilation_rows,int dilation_cols,int stride_rows,int stride_cols,Tensor * output,TensorFormat data_format)225   static bool Run(OpKernelContext* ctx, const Tensor& input,
226                   const Tensor& filter, int batch, int input_rows,
227                   int input_cols, int in_depth, int filter_rows,
228                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
229                   int out_cols, int out_depth, int dilation_rows,
230                   int dilation_cols, int stride_rows, int stride_cols,
231                   Tensor* output, TensorFormat data_format) {
232     auto num_threads =
233         ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
234     // See libxsmm_dnn.h for this struct definition.
235     libxsmm_dnn_conv_desc desc;
236     desc.N = batch;
237     desc.C = in_depth;
238     desc.H = input_rows;
239     desc.W = input_cols;
240     desc.K = out_depth;
241     desc.R = filter_rows;
242     desc.S = filter_cols;
243     desc.u = stride_rows;
244     desc.v = stride_cols;
245     desc.pad_h = pad_rows;
246     desc.pad_w = pad_cols;
247     desc.pad_h_in = 0;
248     desc.pad_w_in = 0;
249     desc.pad_h_out = 0;
250     desc.pad_w_out = 0;
251     desc.threads = num_threads;
252     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
253     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
254     desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
255     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
256     desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
257     desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
258 
259     if (dilation_rows != 1 || dilation_cols != 1 ||
260         !CanUseXsmmConv2D(desc, data_format)) {
261       return false;
262     }
263 
264     auto input_ptr = input.template flat<float>().data();
265     auto filter_ptr = filter.template flat<float>().data();
266     auto output_ptr = output->template flat<float>().data();
267 
268     bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
269         ctx, desc, input_ptr, filter_ptr, output_ptr);
270     return success;
271   }
272 };
273 #endif
274 
275 #define TF_REQUIRES(EXP, STATUS)                \
276   do {                                          \
277     if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
278   } while (false)
279 
InitConv2DParameters(const OpKernelConstruction * context,Conv2DParameters * params)280 Status InitConv2DParameters(const OpKernelConstruction* context,
281                             Conv2DParameters* params) {
282   TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
283   TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
284   TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
285   if (context->HasAttr("explicit_paddings")) {
286     TF_RETURN_IF_ERROR(
287         context->GetAttr("explicit_paddings", &params->explicit_paddings));
288   }
289   string data_format_string;
290   TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
291   TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
292               errors::InvalidArgument("Invalid data format"));
293 
294   const auto& strides = params->strides;
295   const auto& dilations = params->dilations;
296   const auto& data_format = params->data_format;
297 
298   TF_REQUIRES(dilations.size() == 4,
299               errors::InvalidArgument("Sliding window dilations field must "
300                                       "specify 4 dimensions"));
301   TF_REQUIRES(strides.size() == 4,
302               errors::InvalidArgument("Sliding window strides field must "
303                                       "specify 4 dimensions"));
304   const int64 stride_n = GetTensorDim(strides, data_format, 'N');
305   const int64 stride_c = GetTensorDim(strides, data_format, 'C');
306   const int64 stride_h = GetTensorDim(strides, data_format, 'H');
307   const int64 stride_w = GetTensorDim(strides, data_format, 'W');
308   TF_REQUIRES(
309       stride_n == 1 && stride_c == 1,
310       errors::InvalidArgument("Current implementation does not yet support "
311                               "strides in the batch and depth dimensions."));
312   TF_REQUIRES(stride_h > 0 && stride_w > 0,
313               errors::InvalidArgument(
314                   "Row and column strides should be larger than 0."));
315 
316   const int64 dilation_n = GetTensorDim(dilations, data_format, 'N');
317   const int64 dilation_c = GetTensorDim(dilations, data_format, 'C');
318   const int64 dilation_h = GetTensorDim(dilations, data_format, 'H');
319   const int64 dilation_w = GetTensorDim(dilations, data_format, 'W');
320   TF_REQUIRES(
321       dilation_n == 1 && dilation_c == 1,
322       errors::InvalidArgument("Current implementation does not yet support "
323                               "dilations in the batch and depth dimensions."));
324   TF_REQUIRES(
325       dilation_h > 0 && dilation_w > 0,
326       errors::InvalidArgument("Dilated rates should be larger than 0."));
327 
328   TF_RETURN_IF_ERROR(CheckValidPadding(params->padding,
329                                        params->explicit_paddings,
330                                        /*num_dims=*/4, data_format));
331 
332   return Status::OK();
333 }
334 
ComputeConv2DDimension(const Conv2DParameters & params,const Tensor & input,const Tensor & filter,Conv2DDimensions * dimensions)335 Status ComputeConv2DDimension(const Conv2DParameters& params,
336                               const Tensor& input, const Tensor& filter,
337                               Conv2DDimensions* dimensions) {
338   // Check that 2D convolution input and filter have exactly 4 dimensions.
339   TF_REQUIRES(input.dims() == 4,
340               errors::InvalidArgument("input must be 4-dimensional",
341                                       input.shape().DebugString()));
342   TF_REQUIRES(filter.dims() == 4,
343               errors::InvalidArgument("filter must be 4-dimensional: ",
344                                       filter.shape().DebugString()));
345   for (int i = 0; i < 3; i++) {
346     TF_REQUIRES(
347         FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
348         errors::InvalidArgument("filter too large"));
349   }
350 
351   // The last dimension for input is in_depth. Check that it is the same as the
352   // filter's in_depth or it is evenly divisible by filter's in_depth.
353   const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C');
354   const int64 patch_depth_raw = filter.dim_size(2);
355   TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
356               errors::InvalidArgument("Input depth too large"));
357   TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
358               errors::InvalidArgument("Patch depth too large"));
359   const int in_depth = static_cast<int>(in_depth_raw);
360   const int patch_depth = static_cast<int>(patch_depth_raw);
361   TF_REQUIRES(in_depth % patch_depth == 0,
362               errors::InvalidArgument(
363                   "input depth must be evenly divisible by filter depth: ",
364                   in_depth, " vs ", patch_depth));
365 
366   // The last dimension for filter is out_depth.
367   const int out_depth = static_cast<int>(filter.dim_size(3));
368 
369   // The second dimension for input is rows/height.
370   // The first dimension for filter is rows/height.
371   const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H');
372   TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
373               errors::InvalidArgument("Input rows too large"));
374   const int input_rows = static_cast<int>(input_rows_raw);
375   const int filter_rows = static_cast<int>(filter.dim_size(0));
376 
377   // The third dimension for input is columns/width.
378   // The second dimension for filter is columns/width.
379   const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W');
380   TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
381               errors::InvalidArgument("Input cols too large"));
382   const int input_cols = static_cast<int>(input_cols_raw);
383   const int filter_cols = static_cast<int>(filter.dim_size(1));
384 
385   // The first dimension for input is batch.
386   const int64 batch_raw = GetTensorDim(input, params.data_format, 'N');
387   TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
388               errors::InvalidArgument("batch is too large"));
389   const int batch = static_cast<int>(batch_raw);
390 
391   // Take the stride and dilation from the second and third dimensions only (we
392   // do not support striding or dilation on the batch or depth dimension).
393   const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
394   const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
395   const int dilation_rows =
396       GetTensorDim(params.dilations, params.data_format, 'H');
397   const int dilation_cols =
398       GetTensorDim(params.dilations, params.data_format, 'W');
399 
400   int64 pad_rows_before, pad_rows_after, pad_cols_before, pad_cols_after;
401   if (params.padding == Padding::EXPLICIT) {
402     GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'H',
403                              &pad_rows_before, &pad_rows_after);
404     GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'W',
405                              &pad_cols_before, &pad_cols_after);
406   }
407 
408   // Compute windowed output sizes for rows and columns.
409   int64 out_rows = 0, out_cols = 0;
410   TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
411       input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
412       &out_rows, &pad_rows_before, &pad_rows_after));
413   TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
414       input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
415       &out_cols, &pad_cols_before, &pad_cols_after));
416 
417   dimensions->batch = batch;
418   dimensions->input_rows = input_rows;
419   dimensions->input_cols = input_cols;
420   dimensions->in_depth = in_depth;
421   dimensions->filter_rows = filter_rows;
422   dimensions->filter_cols = filter_cols;
423   dimensions->patch_depth = patch_depth;
424   dimensions->out_depth = out_depth;
425   dimensions->stride_rows = stride_rows;
426   dimensions->stride_cols = stride_cols;
427   dimensions->dilation_rows = dilation_rows;
428   dimensions->dilation_cols = dilation_cols;
429   dimensions->out_rows = out_rows;
430   dimensions->out_cols = out_cols;
431   dimensions->pad_rows_before = pad_rows_before;
432   dimensions->pad_rows_after = pad_rows_after;
433   dimensions->pad_cols_before = pad_cols_before;
434   dimensions->pad_cols_after = pad_cols_after;
435 
436   return Status::OK();
437 }
438 
439 #undef TF_REQUIRES
440 
441 template <typename Device, typename T>
442 class Conv2DOp : public BinaryOp<T> {
443  public:
Conv2DOp(OpKernelConstruction * context)444   explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
445     OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
446 
447     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
448     use_cudnn_ &= CanUseCudnn();
449     cudnn_use_autotune_ = CudnnUseAutotune();
450   }
451 
Compute(OpKernelContext * context)452   void Compute(OpKernelContext* context) override {
453     // Input tensor is of the following dimensions:
454     // [ batch, in_rows, in_cols, in_depth ]
455     const Tensor& input = context->input(0);
456 
457     // Input filter is of the following dimensions:
458     // [ filter_rows, filter_cols, in_depth, out_depth]
459     const Tensor& filter = context->input(1);
460 
461     Conv2DDimensions dimensions;
462     OP_REQUIRES_OK(context,
463                    ComputeConv2DDimension(params_, input, filter, &dimensions));
464 
465     TensorShape out_shape = ShapeFromFormat(
466         params_.data_format, dimensions.batch, dimensions.out_rows,
467         dimensions.out_cols, dimensions.out_depth);
468 
469     // Output tensor is of the following dimensions:
470     // [ in_batch, out_rows, out_cols, out_depth ]
471     Tensor* output = nullptr;
472     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
473 
474     VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
475             << ", patch_depth = " << dimensions.patch_depth
476             << ", input_cols = " << dimensions.input_cols
477             << ", filter_cols = " << dimensions.filter_cols
478             << ", input_rows = " << dimensions.input_rows
479             << ", filter_rows = " << dimensions.filter_rows
480             << ", stride_rows = " << dimensions.stride_rows
481             << ", stride_cols = " << dimensions.stride_cols
482             << ", dilation_rows = " << dimensions.dilation_rows
483             << ", dilation_cols = " << dimensions.dilation_cols
484             << ", out_depth = " << dimensions.out_depth;
485 
486     // If there is nothing to compute, return.
487     if (out_shape.num_elements() == 0) {
488       return;
489     }
490 
491 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
492     if (params_.padding != EXPLICIT &&
493         LaunchXsmmConvOp<Device, T>::Run(
494             context, input, filter, dimensions.batch, dimensions.input_rows,
495             dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
496             dimensions.filter_cols, dimensions.pad_rows_before,
497             dimensions.pad_cols_before, dimensions.out_rows,
498             dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
499             dimensions.dilation_cols, dimensions.stride_rows,
500             dimensions.stride_cols, output, params_.data_format)) {
501       return;
502     }
503 #endif
504 
505     if (params_.padding != EXPLICIT &&
506         LaunchDeepConvOp<Device, T>::Run(
507             context, input, filter, dimensions.batch, dimensions.input_rows,
508             dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
509             dimensions.filter_cols, dimensions.pad_rows_before,
510             dimensions.pad_cols_before, dimensions.out_rows,
511             dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
512             dimensions.dilation_cols, dimensions.stride_rows,
513             dimensions.stride_cols, output, params_.data_format)) {
514       return;
515     }
516 
517     launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
518               dimensions.dilation_rows, dimensions.dilation_cols,
519               dimensions.stride_rows, dimensions.stride_cols, params_.padding,
520               params_.explicit_paddings, output, params_.data_format);
521   }
522 
523  private:
524   Conv2DParameters params_;
525   bool use_cudnn_;
526   bool cudnn_use_autotune_;
527 
528   LaunchConv2DOp<Device, T> launcher_;
529 
530   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
531 };
532 
533 #define REGISTER_CPU(T)                                         \
534   REGISTER_KERNEL_BUILDER(                                      \
535       Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
536       Conv2DOp<CPUDevice, T>);
537 
538 // If we're using the alternative GEMM-based implementation of Conv2D for the
539 // CPU implementation, don't register this EigenTensor-based version.
540 #if !defined(USE_GEMM_FOR_CONV)
541 TF_CALL_half(REGISTER_CPU);
542 TF_CALL_float(REGISTER_CPU);
543 TF_CALL_double(REGISTER_CPU);
544 #endif  // USE_GEMM_FOR_CONV
545 
546 // To be used inside depthwise_conv_op.cc.
547 template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
548 template struct LaunchConv2DOp<CPUDevice, float>;
549 template struct LaunchConv2DOp<CPUDevice, double>;
550 
551 #if GOOGLE_CUDA
GetDnnWorkspaceLimit(const string & envvar_in_mb,int64 default_value_in_bytes)552 int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
553                            int64 default_value_in_bytes) {
554   const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
555   if (workspace_limit_in_mb_str != nullptr &&
556       strcmp(workspace_limit_in_mb_str, "") != 0) {
557     int64 scratch_limit_in_mb = -1;
558     if (strings::safe_strto64(workspace_limit_in_mb_str,
559                               &scratch_limit_in_mb)) {
560       return scratch_limit_in_mb * (1 << 20);
561     } else {
562       LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
563                    << workspace_limit_in_mb_str;
564     }
565   }
566   return default_value_in_bytes;
567 }
568 
569 // A dummy type to group forward convolution autotune results together.
570 struct ConvAutoTuneGroup {
nametensorflow::ConvAutoTuneGroup571   static string name() { return "Conv"; }
572 };
573 typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
574                           se::dnn::AlgorithmConfig>
575     AutoTuneConv;
576 
577 template <typename T>
operator ()(OpKernelContext * ctx,bool use_cudnn,bool cudnn_use_autotune,const Tensor & input_param,const Tensor & filter,int row_dilation,int col_dilation,int row_stride,int col_stride,const Padding & padding,const std::vector<int64> & explicit_paddings,Tensor * output,TensorFormat data_format)578 void LaunchConv2DOp<GPUDevice, T>::operator()(
579     OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
580     const Tensor& input_param, const Tensor& filter, int row_dilation,
581     int col_dilation, int row_stride, int col_stride, const Padding& padding,
582     const std::vector<int64>& explicit_paddings, Tensor* output,
583     TensorFormat data_format) {
584   using se::dnn::AlgorithmConfig;
585   using se::dnn::AlgorithmDesc;
586   using se::dnn::ProfileResult;
587   auto* stream = ctx->op_device_context()->stream();
588   OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
589 
590   if (!use_cudnn) {
591     ctx->SetStatus(
592         errors::Unimplemented("Conv2D for GPU is not currently supported "
593                               "without cudnn"));
594     return;
595   }
596 
597   Tensor input = input_param;
598   const int64 in_batch = GetTensorDim(input, data_format, 'N');
599   int64 in_rows = GetTensorDim(input, data_format, 'H');
600   int64 in_cols = GetTensorDim(input, data_format, 'W');
601   const int64 in_depths = GetTensorDim(input, data_format, 'C');
602   const int64 patch_rows = filter.dim_size(0);
603   const int64 patch_cols = filter.dim_size(1);
604   const int64 patch_depths = filter.dim_size(2);
605 
606   // If the filter in-depth (patch_depths) is 1 and smaller than the input
607   // depth, it's a depthwise convolution. More generally, if the filter in-depth
608   // divides but is smaller than the input depth, it is a grouped convolution.
609   bool is_grouped_convolution = patch_depths != in_depths;
610   if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
611       row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
612       col_stride == 1 && data_format == FORMAT_NHWC &&
613       (padding == VALID || padding == SAME)) {
614     // 1x1 filter, so call cublas directly.
615     const uint64 m = in_batch * in_rows * in_cols;
616     const uint64 k = patch_depths;
617     const uint64 n = filter.dim_size(3);
618 
619     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
620                                 input.template flat<T>().size());
621     auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
622                                 filter.template flat<T>().size());
623     auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
624                                 output->template flat<T>().size());
625 
626     auto no_transpose = se::blas::Transpose::kNoTranspose;
627     bool blas_launch_status =
628         stream
629             ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
630                            a_ptr, k, 0.0f, &c_ptr, n)
631             .ok();
632     if (!blas_launch_status) {
633       ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
634                                       ", n=", n, ", k=", k));
635     }
636     return;
637   } else if (patch_rows == in_rows && patch_cols == in_cols &&
638              !is_grouped_convolution && row_dilation == 1 &&
639              col_dilation == 1 && padding == VALID &&
640              data_format == FORMAT_NHWC) {
641     // The input data and filter have the same height/width, so call cublas
642     // directly.
643     const uint64 m = in_batch;
644     const uint64 k = patch_rows * patch_cols * patch_depths;
645     const uint64 n = filter.dim_size(3);
646 
647     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
648                                 input.template flat<T>().size());
649     auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
650                                 filter.template flat<T>().size());
651     auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
652                                 output->template flat<T>().size());
653 
654     auto no_transpose = se::blas::Transpose::kNoTranspose;
655     bool blas_launch_status =
656         stream
657             ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
658                            a_ptr, k, 0.0f, &c_ptr, n)
659             .ok();
660     if (!blas_launch_status) {
661       ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
662                                       ", n=", n, ", k=", k));
663     }
664     return;
665   }
666 
667   const int64 out_batch = GetTensorDim(*output, data_format, 'N');
668   const int64 out_rows = GetTensorDim(*output, data_format, 'H');
669   const int64 out_cols = GetTensorDim(*output, data_format, 'W');
670   const int64 out_depths = GetTensorDim(*output, data_format, 'C');
671   int64 padding_top = -1, padding_bottom = -1;
672   int64 padding_left = -1, padding_right = -1;
673   if (padding == EXPLICIT) {
674     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
675                              &padding_bottom);
676     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
677                              &padding_right);
678   }
679   int64 out_rows_check, out_cols_check;
680   Status status = GetWindowedOutputSizeVerboseV2(
681       in_rows, patch_rows, row_dilation, row_stride, padding, &out_rows_check,
682       &padding_top, &padding_bottom);
683   // The status is guaranteed to be OK because we checked the output and padding
684   // was valid earlier.
685   TF_CHECK_OK(status);
686   DCHECK_EQ(out_rows, out_rows_check);
687   status = GetWindowedOutputSizeVerboseV2(in_cols, patch_cols, col_dilation,
688                                           col_stride, padding, &out_cols_check,
689                                           &padding_left, &padding_right);
690   TF_CHECK_OK(status);
691   DCHECK_EQ(out_cols, out_cols_check);
692 
693   const int64 common_padding_rows = std::min(padding_top, padding_bottom);
694   const int64 common_padding_cols = std::min(padding_left, padding_right);
695   if (padding_top != padding_bottom || padding_left != padding_right) {
696     // cuDNN only supports padding the same amount on the left and right sides,
697     // and on the top and bottom sides. So we manually create a new padded
698     // input tensor such that we can pass it to cuDNN.
699 
700     // TODO(reedwm): In some cases, we can avoid an allocation even if the two
701     // padding sides are different. For example, if the input is 2x2, the filter
702     // is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the result is
703     // equivalent to as if the padding is (1, 1, 1, 1). Changing the padding in
704     // such a way would allow us to avoid the allocation.
705     Tensor transformed_input;
706     const int64 padding_rows_diff = std::abs(padding_bottom - padding_top);
707     const int64 padding_cols_diff = std::abs(padding_right - padding_left);
708     const int64 new_in_rows = in_rows + padding_rows_diff;
709     const int64 new_in_cols = in_cols + padding_cols_diff;
710     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
711                             DataTypeToEnum<T>::value,
712                             ShapeFromFormat(data_format, in_batch, new_in_rows,
713                                             new_in_cols, in_depths),
714                             &transformed_input));
715 
716     const int64 input_pad_top = padding_top - common_padding_rows;
717     const int64 input_pad_bottom = padding_bottom - common_padding_rows;
718     const int64 input_pad_left = padding_left - common_padding_cols;
719     const int64 input_pad_right = padding_right - common_padding_cols;
720     bool in_bounds =
721         FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
722         FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
723         FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
724         FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
725     if (!in_bounds) {
726       ctx->SetStatus(errors::InvalidArgument("Padding is too large."));
727       return;
728     }
729     functor::PadInput<GPUDevice, T, int, 4>()(
730         ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
731         {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
732         {{static_cast<int>(input_pad_bottom),
733           static_cast<int>(input_pad_right)}},
734         To32Bit(transformed_input.tensor<T, 4>()), data_format);
735 
736     input = transformed_input;
737     in_rows = new_in_rows;
738     in_cols = new_in_cols;
739   }
740 
741   if (data_format == FORMAT_NHWC) {
742     // Convert the input tensor from NHWC to NCHW.
743     TensorShape nchw_shape =
744         ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
745     if (in_depths > 1) {
746       Tensor transformed_input;
747       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
748                                              nchw_shape, &transformed_input));
749       functor::NHWCToNCHW<GPUDevice, T, 4>()(
750           ctx->eigen_device<GPUDevice>(),
751           const_cast<const Tensor&>(input).tensor<T, 4>(),
752           transformed_input.tensor<T, 4>());
753       input = transformed_input;
754     } else {
755       // If depth <= 1, then just reshape.
756       CHECK(input.CopyFrom(input, nchw_shape));
757     }
758   }
759 
760   CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
761       << "Negative row or col paddings: (" << common_padding_rows << ", "
762       << common_padding_cols << ")";
763   se::dnn::BatchDescriptor input_desc;
764   input_desc.set_count(in_batch)
765       .set_feature_map_count(in_depths)
766       .set_height(in_rows)
767       .set_width(in_cols)
768       .set_layout(se::dnn::DataLayout::kBatchDepthYX);
769   se::dnn::BatchDescriptor output_desc;
770   output_desc.set_count(out_batch)
771       .set_height(out_rows)
772       .set_width(out_cols)
773       .set_feature_map_count(out_depths)
774       .set_layout(se::dnn::DataLayout::kBatchDepthYX);
775   se::dnn::FilterDescriptor filter_desc;
776   filter_desc.set_input_filter_height(patch_rows)
777       .set_input_filter_width(patch_cols)
778       .set_input_feature_map_count(patch_depths)
779       .set_output_feature_map_count(filter.dim_size(3));
780   se::dnn::ConvolutionDescriptor conv_desc;
781   conv_desc.set_vertical_dilation_rate(row_dilation)
782       .set_horizontal_dilation_rate(col_dilation)
783       .set_vertical_filter_stride(row_stride)
784       .set_horizontal_filter_stride(col_stride)
785       .set_zero_padding_height(common_padding_rows)
786       .set_zero_padding_width(common_padding_cols)
787       .set_group_count(in_depths / patch_depths);
788 
789   Tensor transformed_filter;
790   OP_REQUIRES_OK(ctx, ctx->allocate_temp(
791                           DataTypeToEnum<T>::value,
792                           TensorShape({filter.dim_size(3), filter.dim_size(2),
793                                        filter.dim_size(0), filter.dim_size(1)}),
794                           &transformed_filter));
795   functor::TransformFilter<GPUDevice, T, int, 4>()(
796       ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
797       To32Bit(filter.tensor<T, 4>()),
798       To32Bit(transformed_filter.tensor<T, 4>()));
799 
800   Tensor transformed_output;
801   if (data_format == FORMAT_NHWC) {
802     // Only allocate temporary memory when a layout transformation is needed.
803     OP_REQUIRES_OK(
804         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
805                                 ShapeFromFormat(FORMAT_NCHW, out_batch,
806                                                 out_rows, out_cols, out_depths),
807                                 &transformed_output));
808   } else {
809     transformed_output = *output;
810   }
811 
812   auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
813                                   input.template flat<T>().size());
814   auto filter_ptr =
815       AsDeviceMemory(transformed_filter.template flat<T>().data(),
816                      transformed_filter.template flat<T>().size());
817   auto output_ptr =
818       AsDeviceMemory(transformed_output.template flat<T>().data(),
819                      transformed_output.template flat<T>().size());
820 
821   static int64 ConvolveScratchSize = GetDnnWorkspaceLimit(
822       // default value is in bytes despite the name of the environment variable
823       "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
824   );
825 
826   int device_id = stream->parent()->device_ordinal();
827   DataType dtype = input.dtype();
828   ConvParameters conv_parameters = {
829       in_batch,                 // batch
830       in_depths,                // in_depths
831       {{in_rows,                // in_rows
832         in_cols}},              // in_cols
833       FORMAT_NCHW,              // compute_data_format
834       out_depths,               // out_depths
835       {{patch_rows,             // filter_rows
836         patch_cols,             // filter_cols
837         patch_depths}},         // filter_depths
838       {{row_dilation,           // dilation_rows
839         col_dilation}},         // dilation_cols
840       {{row_stride,             // stride_rows
841         col_stride}},           // stride_cols
842       {{common_padding_rows,    // padding_rows
843         common_padding_cols}},  // padding_cols
844       dtype,                    // tensor datatype
845       device_id,                // device_id
846   };
847   AlgorithmConfig algorithm_config;
848   if (cudnn_use_autotune &&
849       !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
850     std::vector<AlgorithmDesc> algorithms;
851     OP_REQUIRES(
852         ctx,
853         stream->parent()->GetConvolveAlgorithms(
854             conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
855                 stream->parent()),
856             &algorithms),
857         errors::Unknown("Failed to get convolution algorithm. This is probably "
858                         "because cuDNN failed to initialize, so try looking to "
859                         "see if a warning log message was printed above."));
860     std::vector<tensorflow::AutotuneResult> results;
861     for (auto profile_algorithm : algorithms) {
862       // TODO(zhengxq): profile each algorithm multiple times to better
863       // accuracy.
864       DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
865       ProfileResult profile_result;
866       bool cudnn_launch_status =
867           stream
868               ->ThenConvolveWithAlgorithm(
869                   input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
870                   output_desc, &output_ptr, &scratch_allocator,
871                   AlgorithmConfig(profile_algorithm), &profile_result)
872               .ok();
873       if (cudnn_launch_status) {
874         if (profile_result.is_valid()) {
875           results.emplace_back();
876           auto& result = results.back();
877           result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
878           result.mutable_conv()->set_tensor_ops_enabled(
879               profile_algorithm.tensor_ops_enabled());
880           result.mutable_success()->set_scratch_bytes(
881               scratch_allocator.TotalByteSize());
882           *result.mutable_success()->mutable_run_time() =
883               proto_utils::ToDurationProto(
884                   absl::Milliseconds(profile_result.elapsed_time_in_ms()));
885         }
886       }
887     }
888     LogConvAutotuneResults(ctx->op_kernel().def(), input, transformed_filter,
889                            transformed_output, stream->parent(), results);
890     OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
891     AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
892   }
893 
894   DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
895   bool cudnn_launch_status =
896       stream
897           ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
898                                       filter_ptr, conv_desc, output_desc,
899                                       &output_ptr, &scratch_allocator,
900                                       algorithm_config, nullptr)
901           .ok();
902 
903   if (!cudnn_launch_status) {
904     ctx->SetStatus(errors::Internal(
905         "cuDNN launch failure : input shape(", input.shape().DebugString(),
906         ") filter shape(", filter.shape().DebugString(), ")"));
907   }
908 
909   // Convert the output tensor back from NCHW to NHWC.
910   if (data_format == FORMAT_NHWC) {
911     functor::NCHWToNHWC<GPUDevice, T, 4>()(
912         ctx->eigen_device<GPUDevice>(),
913         const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
914         output->tensor<T, 4>());
915   }
916 }
917 
918 // Forward declarations of the functor specializations for GPU.
919 namespace functor {
920 #define DECLARE_GPU_SPEC(T)                                                 \
921   template <>                                                               \
922   void SpatialConvolution<GPUDevice, T>::operator()(                        \
923       const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
924       typename TTypes<T, 4>::ConstTensor input,                             \
925       typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
926       int col_stride, int row_dilation, int col_dilation,                   \
927       const Eigen::PaddingType& padding,                                    \
928       const Eigen::NoOpOutputKernel& output_kernel);                        \
929   extern template struct SpatialConvolution<GPUDevice, T>;                  \
930   template <>                                                               \
931   void MatMulConvFunctor<GPUDevice, T>::operator()(                         \
932       const GPUDevice& d, typename TTypes<T, 2>::Tensor out,                \
933       typename TTypes<T, 2>::ConstTensor in0,                               \
934       typename TTypes<T, 2>::ConstTensor in1,                               \
935       const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, \
936       const Eigen::NoOpOutputKernel& output_kernel);                        \
937   extern template struct MatMulConvFunctor<GPUDevice, T>;                   \
938   template <>                                                               \
939   void TransformFilter<GPUDevice, T, int, 4>::operator()(                   \
940       const GPUDevice& d, FilterTensorFormat dst_filter_format,             \
941       typename TTypes<T, 4, int>::ConstTensor in,                           \
942       typename TTypes<T, 4, int>::Tensor out);                              \
943   extern template struct TransformFilter<GPUDevice, T, int, 4>;             \
944   template <>                                                               \
945   void PadInput<GPUDevice, T, int, 4>::operator()(                          \
946       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,       \
947       const std::array<int, 2>& padding_left,                               \
948       const std::array<int, 2>& padding_right,                              \
949       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format);    \
950   extern template struct PadInput<GPUDevice, T, int, 4>
951 
952 DECLARE_GPU_SPEC(float);
953 DECLARE_GPU_SPEC(Eigen::half);
954 DECLARE_GPU_SPEC(double);
955 #undef DECLARE_GPU_SPEC
956 }  // namespace functor
957 
958 // Registration of the GPU implementations.
959 REGISTER_KERNEL_BUILDER(
960     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
961     Conv2DOp<GPUDevice, Eigen::half>);
962 REGISTER_KERNEL_BUILDER(
963     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
964     Conv2DOp<GPUDevice, float>);
965 REGISTER_KERNEL_BUILDER(
966     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
967     Conv2DOp<GPUDevice, double>);
968 
969 // To be used inside depthwise_conv_op.cc.
970 template struct LaunchConv2DOp<GPUDevice, float>;
971 template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
972 template struct LaunchConv2DOp<GPUDevice, double>;
973 
974 #endif  // GOOGLE_CUDA
975 
976 }  // namespace tensorflow
977