/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/bias_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/redux_functor.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/tensor_format.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

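// Splits the shape of `value_tensor` into (batch, height, width, depth,
// channel) according to `data_format`: for NHWC the channel is the last
// dimension and all leading dimensions are folded into `batch`; for NCHW any
// missing trailing spatial dimensions are left at 1.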
void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
                      int32* batch, int32* height, int32* width, int32* depth,
                      int32* channel) {
  *batch = 1;
  *height = 1;
  *width = 1;
  *depth = 1;
  *channel = 1;
  if (data_format == FORMAT_NHWC) {
    int32_t channel_dim = value_tensor.dims() - 1;
    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
    for (int32_t i = 0; i < channel_dim; i++) {
      *batch *= static_cast<int32>(value_tensor.dim_size(i));
    }
  } else if (data_format == FORMAT_NCHW) {
    *batch = static_cast<int32>(value_tensor.dim_size(0));
    *channel = static_cast<int32>(value_tensor.dim_size(1));
    *height = static_cast<int32>(value_tensor.dim_size(2));
    if (value_tensor.dims() > 3) {
      *width = static_cast<int32>(value_tensor.dim_size(3));
    }
    if (value_tensor.dims() > 4) {
      *depth = static_cast<int32>(value_tensor.dim_size(4));
    }
  }
}

template <class T>
struct AccumulatorType {
  typedef T type;
};

// float is faster on the CPU than half, and also more precise,
// so use float for the temporary accumulators.
template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};

}  // namespace

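// BiasAdd / BiasAddV1 on CPU: adds a 1-D bias vector to the input along its
// channel dimension (the last dimension for NHWC, dimension 1 for NCHW).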
template <typename Device, typename T>
class BiasOp : public BinaryOp<T> {
 public:
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ", bias.shape()));

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    int channel_dim;
    if (data_format_ == FORMAT_NCHW) {
      channel_dim = 1;  // NCHW always has the channel dimension at index 1
                        // (for 3-, 4-, or 5-dimensional data).
    } else {
      channel_dim = input.shape().dims() - 1;  // End of code by intel_tf.
    }

    OP_REQUIRES(context,
                bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
                errors::InvalidArgument(
                    "Must provide as many biases as the channel dimension "
                    "of the input tensor: ",
                    bias.shape(), " vs. ", input.shape()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() == 0) return;

    functor::Bias<Device, T> functor;
    const Device& d = context->eigen_device<Device>();
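    // NCHW inputs with more than two dimensions cannot broadcast the bias over
    // the innermost dimension, so use a collapsed 2-D view centered on the
    // channel axis; for NHWC (or 2-D NCHW) the bias lines up with the
    // innermost dimension and a flat broadcast suffices.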
    if (data_format_ == FORMAT_NCHW && input.shape().dims() > 2) {
      functor(d, input.flat_inner_outer_dims<T, 2>(1),
              bias.flat_outer_dims<T, 2>(),
              output->flat_inner_outer_dims<T, 2>(1));
    } else {
      functor(d, input.flat<T>(), bias.vec<T>(), output->flat<T>());
    }
  }

 private:
  TensorFormat data_format_;
};

#define REGISTER_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      BiasOp<CPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

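// BiasAddGrad on CPU: the gradient with respect to the bias is the incoming
// gradient summed over every dimension except the channel dimension.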
template <typename Device, typename T>
class BiasGradOp : public OpKernel {
 public:
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape()));

    OP_REQUIRES(
        context,
        FastBoundsCheck(output_backprop.NumElements(),
                        std::numeric_limits<int32>::max()),
        errors::InvalidArgument("BiasGrad requires tensor size <= int32 max"));

    int channel_dim;
    if (data_format_ == FORMAT_NCHW) {
      channel_dim = 1;
    } else {
      channel_dim = output_backprop.shape().dims() - 1;
    }
    Tensor* output = nullptr;
    TensorShape output_shape{output_backprop.shape().dim_size(channel_dim)};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    if (output_backprop.NumElements() == 0) {
      // Eigen often crashes by design on empty tensors, but setZero is safe
      output->template flat<T>().setZero();
    } else {
      // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
      // used.
      using AccumT = typename AccumulatorType<T>::type;
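      // NCHW: the channel axis sits in the middle of (batch, channel, spatial),
      // so reduce both the outer and inner dimensions while keeping the middle
      // one. NHWC: channels are innermost, so reducing all outer dimensions
      // leaves the per-channel sums.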
      if (data_format_ == FORMAT_NCHW) {
        const functor::ReduceMiddleDimensions<
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>,
            Eigen::internal::SumReducer<T>>
            redux;

        auto flat_outer = output_backprop.flat_outer_dims<T, 3>();
        redux(context->eigen_device<Device>(), flat_outer.dimensions(),
              output_backprop, output, 1);
      } else {
        const functor::ReduceOuterDimensions<
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>>
            redux;

        auto flat_inner = output_backprop.flat_inner_dims<T, 2>();
        redux(context->eigen_device<Device>(), flat_inner.dimensions(),
              output_backprop, output);
      }
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the CPU implementations.
#define REGISTER_KERNEL(type)                                           \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasGradOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
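// GPU specialization of BiasAdd: shape checks run on the host and the add is
// dispatched to the BiasGPU kernel declared in bias_op_gpu.h.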
template <typename T>
class BiasOp<GPUDevice, T> : public BinaryOp<T> {
 public:
  typedef GPUDevice Device;
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));
    int32_t batch, height, width, depth, channel;
    GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth,
                     &channel);
    OP_REQUIRES(context, bias.shape().dim_size(0) == channel,
                errors::InvalidArgument(
                    "Must provide as many biases as the channel dimension "
                    "of the input tensor: ",
                    bias.shape().DebugString(), " vs. ", channel, " in ",
                    input.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() > 0) {
      BiasGPU<T>::compute(context->template eigen_device<Device>(),
                          input.flat<T>().data(), bias.flat<T>().data(),
                          output->flat<T>().data(), batch, width, height, depth,
                          channel, data_format_);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                     \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      BiasOp<GPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL

struct BiasGradAutotuneGroup {
  static string name() { return "BiasGrad"; }
};

class BiasAddGradGPUConfig {
 public:
  BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
  string ToString() const {
    if (mode_ == BiasAddGradGPUMode::kNative) {
      return "native CUDA kernel.";
    }
    if (mode_ == BiasAddGradGPUMode::kReduction) {
      return "cub reduction kernel.";
    }
    return "unknown kernel.";
  }
  BiasAddGradGPUMode get_mode() const { return mode_; }
  void set_mode(BiasAddGradGPUMode val) { mode_ = val; }

  bool operator==(const BiasAddGradGPUConfig& other) const {
    return this->mode_ == other.get_mode();
  }

  bool operator!=(const BiasAddGradGPUConfig& other) const {
    return !(*this == other);
  }

 private:
  BiasAddGradGPUMode mode_;
};

// Encapsulate all the shape information that is used in bias add grad
// operations.
class BiasAddParams {
 public:
  // We use a list to maintain both the shape value and the order (data format).
  using SpatialArray = gtl::InlinedVector<int64_t, 4>;
  BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
                DataType dtype, int device_id)
      : in_shape_(in_shape),
        data_format_(data_format),
        dtype_(dtype),
        device_id_(device_id) {
    for (int64_t val : in_shape_) {
      hash_code_ = Hash64Combine(hash_code_, val);
    }
    hash_code_ = Hash64Combine(hash_code_, data_format);
    hash_code_ = Hash64Combine(hash_code_, dtype);
    hash_code_ = Hash64Combine(hash_code_, device_id);
  }
  bool operator==(const BiasAddParams& other) const {
    return this->get_data_as_tuple() == other.get_data_as_tuple();
  }

  bool operator!=(const BiasAddParams& other) const {
    return !(*this == other);
  }
  uint64 hash() const { return hash_code_; }

  string ToString() const {
    // clang-format off
    return strings::StrCat(
        "(", absl::StrJoin(in_shape_, ", "), "), ",
        data_format_, ", ", dtype_, ", ", device_id_);
    // clang-format on
  }

 protected:
  using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;

  ParamsDataType get_data_as_tuple() const {
    return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
  }

  uint64 hash_code_ = 0;

 private:
  SpatialArray in_shape_;
  TensorFormat data_format_;
  DataType dtype_;
  int device_id_;
};

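// Process-wide cache mapping BiasAddParams to the fastest BiasAddGrad
// algorithm found by autotuning.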
typedef AutotuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
                          BiasAddGradGPUConfig>
    AutotuneBiasGrad;

template <typename T>
class BiasGradOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NCHW;
    }
  }

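  // Fast path: hands the whole reduction to the custom BiasGradGPU kernel.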
  void ComputeWithCustomKernel(OpKernelContext* context,
                               const Tensor& output_backprop, int32_t batch,
                               int32_t width, int32_t height, int32_t depth,
                               int32_t channel, Tensor* output) {
    BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
                            output_backprop.template flat<T>().data(),
                            output->flat<T>().data(), batch, width, height,
                            depth, channel, data_format_);
  }

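  // Deterministic alternative: expresses the bias gradient as matrix
  // reductions (for NCHW, a row reduction over the spatial dims followed by a
  // column reduction over the batch; for NHWC, a single column reduction over
  // all non-channel dims).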
  void ComputeWithReduceSum(OpKernelContext* context,
                            const Tensor& output_backprop, int32_t batch,
                            int32_t width, int32_t height, int32_t depth,
                            int32_t channel, Tensor* output) {
    if (data_format_ == FORMAT_NCHW) {
      int32_t row_count = batch * channel;
      int32_t col_count = height * width * depth;
      Tensor temp_grad_outputs;
      // For 'NCHW' format, we perform reduction twice: first HW, then N.
      TensorShape temp_grad_output_shape{row_count, col_count};
      OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                     temp_grad_output_shape,
                                                     &temp_grad_outputs));
      BiasGradGPU<T>::DoRowReduction(
          context, temp_grad_outputs.flat<T>().data(),
          output_backprop.template flat<T>().data(), row_count, col_count);

      row_count = batch;
      col_count = channel;
      BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
                                     temp_grad_outputs.flat<T>().data(),
                                     row_count, col_count);
    } else {
      // For 'NHWC', we simply apply reduction once on NHW.
      int32_t row_count = batch * height * width * depth;
      int32_t col_count = channel;
      BiasGradGPU<T>::DoColReduction(
          context, const_cast<T*>(output->flat<T>().data()),
          reinterpret_cast<const T*>(output_backprop.template flat<T>().data()),
          row_count, col_count);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));
    int32_t batch, height, width, depth, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &depth, &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (channel == 0) return;
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
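    // Zero-initialize the gradient so that an empty backprop tensor yields an
    // all-zero result from the early return below.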
    se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
                                    output->NumElements() * sizeof(T));
    stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
    if (output_backprop.NumElements() <= 0) return;
    if (OpDeterminismRequired()) {
      // ComputeWithReduceSum is the only deterministic algorithm.
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
      return;
    }

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = output_backprop.dtype();
    BiasAddParams bias_parameters = {
        {batch, height * width * depth, channel},
        data_format_,
        dtype,
        device_id,
    };

    // Autotune between the two algorithms: the custom kernel and the cub
    // reduction.
    BiasAddGradGPUConfig algo_config;
    if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
      profiler::ScopedAnnotation trace("bias_grad_autotuning");

      BiasGradGPUProfileResult best_result;
      // Initialize the timer.
      perftools::gputools::Timer timer(stream->parent());
      stream->InitTimer(&timer);
      stream->ThenStartTimer(&timer);
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
      stream->ThenStopTimer(&timer);
      uint64 elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Native algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kNative);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      // Try reduction and profile.
      stream->ThenStartTimer(&timer);
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
      stream->ThenStopTimer(&timer);

      elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Reduction algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      algo_config.set_mode(best_result.algorithm());
      AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);

      // Results are already available during autotune, so no need to continue.
      return;
    }

    // Choose the best algorithm based on autotune results.
    if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
    } else {
      // Default to the customized kernel.
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow