/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_3d.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
using stream_executor::dnn::DimIndex;
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchConvOp;

template <typename T>
struct LaunchConvOp<CPUDevice, T> {
  static void launch(OpKernelContext* context, bool cudnn_use_autotune,
                     const Tensor& input, const Tensor& filter,
                     const std::array<int64, 3>& dilations,
                     const std::array<int64, 3>& strides, const Padding padding,
                     TensorFormat data_format, Tensor* output) {
    OP_REQUIRES(context, data_format == FORMAT_NHWC,
                errors::InvalidArgument("CPU implementation of Conv3D "
                                        "currently only supports the NHWC "
                                        "tensor format."));
    OP_REQUIRES(context,
                dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1,
                errors::InvalidArgument("CPU implementation of Conv3D "
                                        "currently only supports dilation "
                                        "rates of 1."));
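    // Note: the strides array is ordered {z, y, x} (planes, rows, cols); the
    // reversed indexing below reflects Eigen's reverse (column-major)
    // dimension ordering relative to TensorFlow's row-major tensors.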
    functor::CuboidConvolution<CPUDevice, T>()(
        context->eigen_device<CPUDevice>(), output->tensor<T, 5>(),
        input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1],
        strides[0], BrainPadding2EigenPadding(padding));
  }
};

template <typename Device, typename T>
class Conv3DOp : public BinaryOp<T> {
 public:
  explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'N') == 1 &&
         GetTensorDim(stride_, data_format_, 'C') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, '0') > 0 &&
         GetTensorDim(stride_, data_format_, '1') > 0 &&
         GetTensorDim(stride_, data_format_, '2') > 0),
        errors::InvalidArgument("Spatial strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'N') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'C') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));
    OP_REQUIRES(
        context,
        (GetTensorDim(dilation_, data_format_, '0') > 0 &&
         GetTensorDim(dilation_, data_format_, '1') > 0 &&
         GetTensorDim(dilation_, data_format_, '2') > 0),
        errors::InvalidArgument("Dilation rates should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_z, in_y, in_x, in_channels ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_z, filter_y, filter_x, in_channels, out_channels]
    const Tensor& filter = context->input(1);

    // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
    // kept consistent between input/filter/output.
    OP_REQUIRES(context, input.dims() == 5,
                errors::InvalidArgument("input must be 5-dimensional"));
    OP_REQUIRES(context, filter.dims() == 5,
                errors::InvalidArgument("filter must be 5-dimensional"));

    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
    const int64 in_batch = GetTensorDim(input, data_format_, 'N');

    const int64 out_depth = filter.dim_size(4);
    OP_REQUIRES(
        context, in_depth == filter.dim_size(3),
        errors::InvalidArgument("input and filter must have the same depth"));

    // Dimension order for these arrays is: z, y, x.
    std::array<int64, 3> input_size = {
        {GetTensorDim(input, data_format_, '0'),
         GetTensorDim(input, data_format_, '1'),
         GetTensorDim(input, data_format_, '2')}};
    std::array<int64, 3> filter_size = {
        {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
    std::array<int64, 3> dilations = {
        {GetTensorDim(dilation_, data_format_, '0'),
         GetTensorDim(dilation_, data_format_, '1'),
         GetTensorDim(dilation_, data_format_, '2')}};
    std::array<int64, 3> strides = {{GetTensorDim(stride_, data_format_, '0'),
                                     GetTensorDim(stride_, data_format_, '1'),
                                     GetTensorDim(stride_, data_format_, '2')}};
    std::array<int64, 3> out, padding;

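    // Get3dOutputSizeV2 computes, per spatial dimension, the output size and
    // total padding. With an effective filter size of
    // (filter - 1) * dilation + 1, SAME yields out = ceil(in / stride) and
    // VALID yields out = ceil((in - effective_filter_size + 1) / stride).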
    OP_REQUIRES_OK(
        context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides,
                                   padding_, &out, &padding));
    TensorShape out_shape = ShapeFromFormat(
        data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth);
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // Return early if nothing to do.
    if (out_shape.num_elements() == 0) return;

    LaunchConvOp<Device, T>::launch(context, cudnn_use_autotune_, input, filter,
                                    dilations, strides, padding_, data_format_,
                                    output);
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool cudnn_use_autotune_;
};

#define REGISTER_CPU_KERNEL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if GOOGLE_CUDA

// A dummy type to group forward convolution autotune results together.
struct Conv3dAutoTuneGroup {
  static string name() { return "Conv3d"; }
};
typedef AutoTuneSingleton<Conv3dAutoTuneGroup, ConvParameters,
                          se::dnn::AlgorithmConfig>
    AutoTuneConv3d;

// TODO(mjanusz): Share logic with 2d implementation as much as possible.
template <typename T>
struct LaunchConvOp<GPUDevice, T> {
  static void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
                     const Tensor& input_param, const Tensor& filter,
                     const std::array<int64, 3>& dilations,
                     const std::array<int64, 3>& strides, const Padding padding,
                     TensorFormat data_format, Tensor* output) {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    Tensor input = input_param;

    const int64 in_batch = GetTensorDim(input, data_format, 'N');
    int64 in_planes = GetTensorDim(input, data_format, '0');
    int64 in_rows = GetTensorDim(input, data_format, '1');
    int64 in_cols = GetTensorDim(input, data_format, '2');
    const int64 in_depth = GetTensorDim(input, data_format, 'C');

    const int64 filter_planes = filter.dim_size(0);
    const int64 filter_rows = filter.dim_size(1);
    const int64 filter_cols = filter.dim_size(2);
    const int64 out_depth = filter.dim_size(4);

    int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
    int64 out_planes = GetTensorDim(*output, data_format, '0');
    int64 out_rows = GetTensorDim(*output, data_format, '1');
    int64 out_cols = GetTensorDim(*output, data_format, '2');

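    // For SAME padding, derive the total padding in each spatial dimension
    // from the precomputed output size: the filter windows span
    // (out - 1) * stride + filter input elements, and anything beyond the
    // input extent must come from padding (clamped at zero).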
    if (padding == Padding::SAME) {
      pad_planes = std::max<int64>(
          0, (out_planes - 1) * strides[0] + filter_planes - in_planes);
      pad_rows = std::max<int64>(
          0, (out_rows - 1) * strides[1] + filter_rows - in_rows);
      pad_cols = std::max<int64>(
          0, (out_cols - 1) * strides[2] + filter_cols - in_cols);
    }

    // NOTE: This only works in NHWC.
    if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 &&
        dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1 &&
        strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
        data_format == FORMAT_NHWC) {
      // 1x1 filter, so call cublas directly.
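      // With a 1x1x1 filter and unit strides in NHWC, the convolution is a
      // plain matrix multiply: the input is an [m, k] matrix of
      // m = batch * planes * rows * cols pixels with k = in_depth channels
      // each, and the filter is a [k, n] matrix with n = out_depth.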
      const uint64 m = in_batch * in_planes * in_rows * in_cols;
      const uint64 k = in_depth;
      const uint64 n = out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                  output->template flat<T>().size());

      auto no_transpose = se::blas::Transpose::kNoTranspose;
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
                             n, a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                        ", n=", n, ", k=", k));
      }
      return;
    } else if (filter_planes == in_planes && filter_rows == in_rows &&
               filter_cols == in_cols && padding == Padding::VALID &&
               data_format == FORMAT_NHWC) {
      // The input data and filter have the same planes/height/width, so call
      // cublas directly.
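      // Each filter covers the entire input volume, so every (batch,
      // out_channel) output element is one dot product of length
      // planes * rows * cols * in_depth; again a single matrix multiply.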
      const uint64 m = in_batch;
      const uint64 k = in_planes * in_rows * in_cols * in_depth;
      const uint64 n = out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                  output->template flat<T>().size());

      auto no_transpose = se::blas::Transpose::kNoTranspose;
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
                             n, a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                        ", n=", n, ", k=", k));
      }
      return;
    }

    if (padding == Padding::SAME) {
      const bool rows_odd = (pad_rows % 2 != 0);
      const bool cols_odd = (pad_cols % 2 != 0);
      const bool planes_odd = (pad_planes % 2 != 0);

      // Necessary because cuDNN only supports symmetric padding.
      // TODO(mjanusz): Consider making this optional? This would save some
      // overhead and would work as long as an op trained this way is only
      // used on GPU.
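      // TF's SAME padding places any extra element at the end of a dimension,
      // so when the total padding is odd the input is pre-padded by one on
      // that side only; the remaining even amount is then handed to cuDNN as
      // pad / 2 on each side.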
      if (rows_odd || cols_odd || planes_odd) {
        const int64 new_in_rows = in_rows + rows_odd;
        const int64 new_in_cols = in_cols + cols_odd;
        const int64 new_in_planes = in_planes + planes_odd;

        Tensor transformed_input;
        TensorShape transformed_shape = ShapeFromFormat(
            data_format, in_batch, {{new_in_planes, new_in_rows, new_in_cols}},
            in_depth);
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, transformed_shape,
                                    &transformed_input));

        functor::PadInput<GPUDevice, T, int, 5>()(
            ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 5>()),
            {{0, 0, 0}}, {{planes_odd, rows_odd, cols_odd}},
            To32Bit(transformed_input.tensor<T, 5>()), data_format);
        input = transformed_input;
        in_rows = new_in_rows;
        in_cols = new_in_cols;
        in_planes = new_in_planes;
      }
    }

    if (data_format == FORMAT_NHWC) {
      const TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, in_batch, {{in_planes, in_rows, in_cols}}, in_depth);
      if (in_depth > 1) {
        Tensor transformed_input;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                               nchw_shape, &transformed_input));
        // input: [b, x, y, z, d]
        // t_input: [b, d, x, y, z]
        // NCDHW is the only format universally supported by cuDNN.
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            ctx->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(input).tensor<T, 5>(),
            transformed_input.tensor<T, 5>());
        input = transformed_input;
      } else {
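        // With a single channel, NHWC and NCHW layouts are byte-identical, so
        // CopyFrom just reshapes while sharing the underlying buffer; no data
        // is moved.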
        CHECK(input.CopyFrom(input, nchw_shape));
      }
    }

    CHECK(pad_rows >= 0 && pad_cols >= 0 && pad_planes >= 0)
        << "Negative paddings: (" << pad_rows << ", " << pad_cols << ", "
        << pad_planes << ")";
    se::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(in_batch)
        .set_feature_map_count(in_depth)
        .set_spatial_dim(DimIndex::X, in_cols)
        .set_spatial_dim(DimIndex::Y, in_rows)
        .set_spatial_dim(DimIndex::Z, in_planes)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(in_batch)
        .set_spatial_dim(DimIndex::X, out_cols)
        .set_spatial_dim(DimIndex::Y, out_rows)
        .set_spatial_dim(DimIndex::Z, out_planes)
        .set_feature_map_count(out_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
        .set_spatial_dim(DimIndex::Y, filter_rows)
        .set_spatial_dim(DimIndex::Z, filter_planes)
        .set_input_feature_map_count(in_depth)
        .set_output_feature_map_count(out_depth);
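    // The dilations and strides arrays are ordered {z, y, x}; DimIndex::X is
    // the innermost spatial dimension, so index 2 maps to X and index 0 to Z.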
    se::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
        .set_dilation_rate(DimIndex::Y, dilations[1])
        .set_dilation_rate(DimIndex::Z, dilations[0])
        .set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, pad_cols / 2)
        .set_zero_padding(DimIndex::Y, pad_rows / 2)
        .set_zero_padding(DimIndex::Z, pad_planes / 2);

    Tensor transformed_filter;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                TensorShape({out_depth, in_depth, filter_planes,
                                             filter_rows, filter_cols}),
                                &transformed_filter));
    // filter: [x, y, z, in, out]
    // t_filter: [out, in, x, y, z]
    functor::TransformFilter<GPUDevice, T, int, 5>()(
        ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
        To32Bit(filter.tensor<T, 5>()),
        To32Bit(transformed_filter.tensor<T, 5>()));

    Tensor transformed_output;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(FORMAT_NCHW, in_batch,
                                 {{out_planes, out_rows, out_cols}}, out_depth),
                 &transformed_output));

    auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                    input.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto output_ptr =
        AsDeviceMemory(transformed_output.template flat<T>().data(),
                       transformed_output.template flat<T>().size());

    static int64 ConvolveScratchSize = GetDnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    ConvParameters conv_parameters = {
        in_batch,
        in_depth,
        {{in_planes, in_rows, in_cols}},
        FORMAT_NCHW,
        out_depth,
        {{filter_planes, filter_rows, filter_cols}},
        {{dilations[0], dilations[1], dilations[2]}},
        {{strides[0], strides[1], strides[2]}},
        {{pad_planes, pad_rows, pad_cols}},
        dtype,
        device_id,
    };

    using se::dnn::AlgorithmConfig;
    using se::dnn::AlgorithmDesc;
    using se::dnn::ProfileResult;

    AlgorithmConfig algorithm_config;

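    // Autotuning: on the first run with a given set of convolution parameters,
    // time every available cuDNN algorithm, cache the fastest in the
    // AutoTuneConv3d singleton, and reuse the cached choice on later calls.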
    if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
                                  conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      OP_REQUIRES(ctx,
                  stream->parent()->GetConvolveAlgorithms(
                      conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
                          stream->parent()),
                      &algorithms),
                  errors::Unknown(
                      "Failed to get convolution algorithm. This is probably "
                      "because cuDNN failed to initialize, so try looking to "
                      "see if a warning log message was printed above."));

      std::vector<tensorflow::AutotuneResult> results;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveWithAlgorithm(
                    input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
                    output_desc, &output_ptr, &scratch_allocator,
                    AlgorithmConfig(profile_algorithm), &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            results.emplace_back();
            auto& result = results.back();
            result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
            result.mutable_conv()->set_tensor_ops_enabled(
                profile_algorithm.tensor_ops_enabled());
            result.mutable_success()->set_scratch_bytes(
                scratch_allocator.TotalByteSize());
            *result.mutable_success()->mutable_run_time() =
                proto_utils::ToDurationProto(
                    absl::Milliseconds(profile_result.elapsed_time_in_ms()));
          }
        }
      }
      LogConvAutotuneResults(ctx->op_kernel().def(), input, filter, *output,
                             stream->parent(), results);
      OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
      AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config);
    }

    DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
                                        filter_ptr, conv_desc, output_desc,
                                        &output_ptr, &scratch_allocator,
                                        algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      ctx->SetStatus(errors::Internal(
          "cuDNN launch failure : input shape(", input.shape().DebugString(),
          ") filter shape(", filter.shape().DebugString(), ")"));
    }

    if (data_format == FORMAT_NHWC) {
      // t_output: [b, out, x, y, z]
      // output: [b, x, y, z, out]
      functor::NCHWToNHWC<GPUDevice, T, 5>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(transformed_output).tensor<T, 5>(),
          output->tensor<T, 5>());
    } else {
      *output = transformed_output;
    }
  }
};

// Forward declarations of the functor specializations for GPU.
// This ensures that the custom implementation is used instead of the default
// Eigen one (which is used for CPU).
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
      typename TTypes<T, 5, int>::ConstTensor in,                     \
      typename TTypes<T, 5, int>::Tensor out);                        \
  template <>                                                         \
  void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void PadInput<GPUDevice, T, int, 5>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      const std::array<int, 3>& padding_left,                         \
      const std::array<int, 3>& padding_right,                        \
      typename TTypes<T, 5, int>::Tensor out, TensorFormat format);   \
  template <>                                                         \
  void NHWCToNCHW<GPUDevice, T, 5>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void NCHWToNHWC<GPUDevice, T, 5>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registration of the GPU implementations.
REGISTER_KERNEL_BUILDER(
    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    Conv3DOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    Conv3DOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    Conv3DOp<GPUDevice, double>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow