/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/bias_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/redux_functor.h"
#include "tensorflow/core/util/tensor_format.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

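// Extracts (batch, height, width, depth, channel) from the shape of
// `value_tensor` for the given data format. For NHWC the channel is the last
// dimension and all leading dimensions are folded into `batch`; for NCHW the
// dimensions are read positionally, with width and depth defaulting to 1 for
// lower-rank inputs.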
void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
                      int32* batch, int32* height, int32* width, int32* depth,
                      int32* channel) {
  *batch = 1;
  *height = 1;
  *width = 1;
  *depth = 1;
  *channel = 1;
  if (data_format == FORMAT_NHWC) {
    int32 channel_dim = value_tensor.dims() - 1;
    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
    for (int32 i = 0; i < channel_dim; i++) {
      *batch *= static_cast<int32>(value_tensor.dim_size(i));
    }
  } else if (data_format == FORMAT_NCHW) {
    *batch = static_cast<int32>(value_tensor.dim_size(0));
    *channel = static_cast<int32>(value_tensor.dim_size(1));
    *height = static_cast<int32>(value_tensor.dim_size(2));
    if (value_tensor.dims() > 3) {
      *width = static_cast<int32>(value_tensor.dim_size(3));
    }
    if (value_tensor.dims() > 4) {
      *depth = static_cast<int32>(value_tensor.dim_size(4));
    }
  }
}

template <class T>
struct AccumulatorType {
  typedef T type;
};

// float is faster on the CPU than half, and also more precise,
// so use float for the temporary accumulators.
template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};

}  // namespace

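// BiasOp adds a 1D bias vector along the channel dimension of its input
// (the last dimension for NHWC, dimension 1 for NCHW).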
template <typename Device, typename T>
class BiasOp : public BinaryOp<T> {
 public:
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    size_t channel_dim;
    if (data_format_ == FORMAT_NCHW) {
      channel_dim = 1;  // NCHW always has the channel dim at index 1 (for 3-,
                        // 4-, or 5-dimensional data).
    } else {
      channel_dim = input.shape().dims() - 1;  // End of code by intel_tf.
    }

    OP_REQUIRES(
        context,
        bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
        errors::InvalidArgument(
            "Must provide as many biases as the channel dimension "
            "of the input tensor: ",
            bias.shape().DebugString(), " vs. ", input.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() == 0) return;

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
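    // NCHW path: reshape the bias to [1, C, 1, ...] and broadcast it across
    // the remaining dimensions before adding it to the input.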
    if (data_format_ == FORMAT_NCHW) {
      int32 batch, height, width, depth, channel;
      GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth,
                       &channel);
      switch (input.shape().dims()) {
        case 3: {
          Eigen::DSizes<int32, 3> three_dims(1, channel, 1);
          Eigen::DSizes<int32, 3> broad_cast_dims(batch, 1, height);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 3>().device(d) =
              input.tensor<T, 3>() + bias.tensor<T, 1>()
                                         .reshape(three_dims)
                                         .broadcast(broad_cast_dims);
        } break;
        case 4: {
          Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
          Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 4>().device(d) =
              input.tensor<T, 4>() +
              bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
        } break;
        case 5: {
          Eigen::DSizes<int32, 5> five_dims(1, channel, 1, 1, 1);
          Eigen::DSizes<int32, 5> broad_cast_dims(batch, 1, height, width,
                                                  depth);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 5>().device(d) =
              input.tensor<T, 5>() +
              bias.tensor<T, 1>().reshape(five_dims).broadcast(broad_cast_dims);
        } break;
        default:
          OP_REQUIRES(context, false,
                      errors::InvalidArgument("Only ranks up to 5 supported: ",
                                              input.shape().DebugString()));
      }
      return;
    }  // End of code by intel_tf.

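    // NHWC (default) path: dispatch on rank to the Eigen-based Bias functor.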
    switch (input.shape().dims()) {
      case 2:
        Compute<2>(context, input, bias, output);
        break;
      case 3:
        Compute<3>(context, input, bias, output);
        break;
      case 4:
        Compute<4>(context, input, bias, output);
        break;
      case 5:
        Compute<5>(context, input, bias, output);
        break;
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Only ranks up to 5 supported: ",
                                            input.shape().DebugString()));
    }
  }

  // Adds biases to an input tensor of rank Dims using the Bias functor.
  template <int Dims>
  void Compute(OpKernelContext* ctx, const Tensor& input, const Tensor& bias,
               Tensor* output) {
    functor::Bias<Device, T, Dims> functor;
    functor(ctx->eigen_device<Device>(), input.tensor<T, Dims>(), bias.vec<T>(),
            output->tensor<T, Dims>());
  }

 private:
  TensorFormat data_format_;
};

#define REGISTER_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      BiasOp<CPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(type)                                          \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("BiasAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"),   \
      BiasOp<SYCLDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("BiasAddV1").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      BiasOp<SYCLDevice, type>);

TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

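// BiasGradOp computes the gradient of BiasAdd with respect to the bias: the
// incoming backprop is summed over every dimension except the channel one.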
template <typename Device, typename T>
class BiasGradOp : public OpKernel {
 public:
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));

    OP_REQUIRES(
        context,
        FastBoundsCheck(output_backprop.NumElements(),
                        std::numeric_limits<int32>::max()),
        errors::InvalidArgument("BiasGrad requires tensor size <= int32 max"));

    int32 batch, height, width, depth, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &depth, &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    if (channel == 0) {
      return;  // Nothing to do
    } else if (output_backprop.NumElements() == 0) {
      // Eigen often crashes by design on empty tensors, but setZero is safe
      output->template flat<T>().setZero();
    } else {
      // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
      // used.
      if (data_format_ == FORMAT_NCHW) {
        Eigen::DSizes<Eigen::Index, 3> three_dims(batch, channel,
                                                  height * width * depth);
#ifdef EIGEN_HAS_INDEX_LIST
        using idx0 = Eigen::type2index<0>;
        using idx2 = Eigen::type2index<2>;
        Eigen::IndexList<idx0, idx2> reduction_axes;
#else
        Eigen::array<Eigen::Index, 2> reduction_axes = {0, 2};
#endif
        output->template flat<T>().device(context->eigen_device<Device>()) =
            output_backprop.flat<T>()
                .template cast<typename AccumulatorType<T>::type>()
                .reshape(three_dims)
                .sum(reduction_axes)
                .template cast<T>();  // End of code by intel_tf.
      } else {
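        // NHWC: fold all non-channel dimensions into the outer (row) dimension
        // and sum over it with ReduceOuterDimensions, accumulating in AccumT
        // (float when T is Eigen::half) for better precision.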
        using AccumT = typename AccumulatorType<T>::type;
        const functor::ReduceOuterDimensions<
            T, AccumT, Eigen::internal::scalar_sum_op<AccumT>>
            redux;

        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width * depth,
                                                channel);
        redux(context->eigen_device<Device>(), two_dims, output_backprop,
              output);
      }
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the CPU implementations.
#define REGISTER_KERNEL(type)                                           \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasGradOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("BiasAddGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      BiasGradOp<SYCLDevice, type>);

TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA
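// GPU specialization of BiasOp: shapes are validated on the host and the
// broadcast add itself is performed by the BiasGPU CUDA kernel.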
template <typename T>
class BiasOp<GPUDevice, T> : public BinaryOp<T> {
 public:
  typedef GPUDevice Device;
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));
    int32 batch, height, width, depth, channel;
    GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth,
                     &channel);
    OP_REQUIRES(context, bias.shape().dim_size(0) == channel,
                errors::InvalidArgument(
                    "Must provide as many biases as the channel dimension "
                    "of the input tensor: ",
                    bias.shape().DebugString(), " vs. ", channel, " in ",
                    input.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() > 0) {
      BiasGPU<T>::compute(context->template eigen_device<Device>(),
                          input.flat<T>().data(), bias.flat<T>().data(),
                          output->flat<T>().data(), batch, width, height, depth,
                          channel, data_format_);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                     \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      BiasOp<GPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

struct BiasGradAutotuneGroup {
  static string name() { return "BiasGrad"; }
};

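// Holds the algorithm choice (native CUDA kernel vs. cub reduction) that the
// autotuner caches for a given BiasAddParams key.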
class BiasAddGradGPUConfig {
 public:
  BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
  string ToString() const {
    if (mode_ == BiasAddGradGPUMode::kNative) {
      return "native CUDA kernel.";
    }
    if (mode_ == BiasAddGradGPUMode::kReduction) {
      return "cub reduction kernel.";
    }
    return "unknown kernel.";
  }
  BiasAddGradGPUMode get_mode() const { return mode_; }
  void set_mode(BiasAddGradGPUMode val) { mode_ = val; }

  bool operator==(const BiasAddGradGPUConfig& other) const {
    return this->mode_ == other.get_mode();
  }

  bool operator!=(const BiasAddGradGPUConfig& other) const {
    return !(*this == other);
  }

 private:
  BiasAddGradGPUMode mode_;
};

// Encapsulate all the shape information that is used in bias add grad
// operations.
class BiasAddParams {
 public:
  // We use a list to maintain both the shape values and their order (data
  // format).
  using SpatialArray = gtl::InlinedVector<int64, 4>;
  BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
                DataType dtype, int device_id)
      : in_shape_(in_shape),
        data_format_(data_format),
        dtype_(dtype),
        device_id_(device_id) {
    for (int64 val : in_shape_) {
      hash_code_ = Hash64Combine(hash_code_, val);
    }
    hash_code_ = Hash64Combine(hash_code_, data_format);
    hash_code_ = Hash64Combine(hash_code_, dtype);
    hash_code_ = Hash64Combine(hash_code_, device_id);
  }
  bool operator==(const BiasAddParams& other) const {
    return this->get_data_as_tuple() == other.get_data_as_tuple();
  }

  bool operator!=(const BiasAddParams& other) const {
    return !(*this == other);
  }
  uint64 hash() const { return hash_code_; }

  string ToString() const {
    // clang-format off
    return strings::StrCat(
        "(", str_util::Join(in_shape_, ", "), "), ",
        data_format_, ", ", dtype_, ", ", device_id_);
    // clang-format on
  }

 protected:
  using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;

  ParamsDataType get_data_as_tuple() const {
    return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
  }

  uint64 hash_code_ = 0;

 private:
  SpatialArray in_shape_;
  TensorFormat data_format_;
  DataType dtype_;
  int device_id_;
};

typedef AutoTuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
                          BiasAddGradGPUConfig>
    AutotuneBiasGrad;

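// GPU specialization of BiasGradOp. The output is first zeroed on the stream;
// the gradient is then computed either by a custom CUDA kernel or by generic
// reductions, with the faster variant autotuned and cached per BiasAddParams.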
template <typename T>
class BiasGradOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NCHW;
    }
  }

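  // Computes the bias gradient with the hand-written BiasGradGPU kernel
  // (the "native" autotune choice).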
  void ComputeWithCustomKernel(OpKernelContext* context,
                               const Tensor& output_backprop, int32 batch,
                               int32 width, int32 height, int32 depth,
                               int32 channel, Tensor* output) {
    BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
                            output_backprop.template flat<T>().data(),
                            output->flat<T>().data(), batch, width, height,
                            depth, channel, data_format_);
  }

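  // Computes the bias gradient with generic row/column reductions (the
  // "reduction" autotune choice).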
  void ComputeWithReduceSum(OpKernelContext* context,
                            const Tensor& output_backprop, int32 batch,
                            int32 width, int32 height, int32 depth,
                            int32 channel, Tensor* output) {
    if (data_format_ == FORMAT_NCHW) {
      int32 row_count = batch * channel;
      int32 col_count = height * width * depth;
      Tensor temp_grad_outputs;
      // For 'NCHW' format, we perform reduction twice: first HW, then N.
      TensorShape temp_grad_output_shape{row_count, col_count};
      OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                     temp_grad_output_shape,
                                                     &temp_grad_outputs));
      BiasGradGPU<T>::DoRowReduction(
          context, temp_grad_outputs.flat<T>().data(),
          output_backprop.template flat<T>().data(), row_count, col_count);

      row_count = batch;
      col_count = channel;
      BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
                                     temp_grad_outputs.flat<T>().data(),
                                     row_count, col_count);
    } else {
      // For 'NHWC', we simply apply reduction once on NHW.
      int32 row_count = batch * height * width * depth;
      int32 col_count = channel;
      BiasGradGPU<T>::DoColReduction(
          context, const_cast<T*>(output->flat<T>().data()),
          reinterpret_cast<const T*>(output_backprop.template flat<T>().data()),
          row_count, col_count);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));
    int32 batch, height, width, depth, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &depth, &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (channel == 0) return;
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
                                    output->NumElements() * sizeof(T));
    stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
    if (output_backprop.NumElements() <= 0) return;

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = output_backprop.dtype();
    BiasAddParams bias_parameters = {
        {batch, height * width * depth, channel},
        data_format_,
        dtype,
        device_id,
    };

    // Autotune between the two algorithms: the customized kernel and the
    // reduction-based one. On the first run for a given set of parameters,
    // profile both and cache the faster choice.
    BiasAddGradGPUConfig algo_config;
    if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
      BiasGradGPUProfileResult best_result;
      // Initialize the timer.
      perftools::gputools::Timer timer(stream->parent());
      stream->InitTimer(&timer);
      stream->ThenStartTimer(&timer);
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
      stream->ThenStopTimer(&timer);
      uint64 elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Native algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kNative);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      // Try reduction and profile.
      stream->ThenStartTimer(&timer);
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
      stream->ThenStopTimer(&timer);

      elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Reduction algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      algo_config.set_mode(best_result.algorithm());
      AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);

      // Results are already available during autotune, so no need to continue.
      return;
    }

    // Choose the best algorithm based on autotune results.
    if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
    } else {
      // Default to the customized kernel.
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow