/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LRN = Local Response Normalization
// See docs in ../ops/nn_ops.cc.
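//
// Each output element is computed as (see the CPU and GPU launchers below):
//
//   output[b, r, c, d] =
//       input[b, r, c, d] /
//       (bias + alpha * sum_{k = d - depth_radius}^{d + depth_radius}
//                           input[b, r, c, k]^2) ^ beta
//
// i.e. every value is scaled by a power of the accumulated squared activity
// in a window along the depth dimension.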

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

#if GOOGLE_CUDA
#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

namespace {

// When the depth is large and beta_ is 0.5 or 1.0, Single-threaded
// LRN is faster than the main band matrix approach used
// below. Benchmarks suggest switching to SingleThreadedLRN when depth > 384.
const int kSingleThreadedLRNDepthCutoff = 384;

// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
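// For example, with depth = 5 and depth_radius = 1 the result is
//
//   1 1 0 0 0
//   1 1 1 0 0
//   0 1 1 1 0
//   0 0 1 1 1
//   0 0 0 1 1
//
// so row d selects the depth window [d - depth_radius, d + depth_radius],
// clipped to the valid range.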
template <typename T>
void GetBandMatrix(int depth, int depth_radius,
                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
  result->setZero();
  for (int row = 0; row < depth; ++row) {
    const int begin = std::max<int>(0, row - depth_radius);
    const int end = std::min<int>(depth, row + depth_radius + 1);
    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
    result->slice(start, sizes).setConstant(T(1));
  }
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchLRN;

template <typename T>
struct LaunchLRN<CPUDevice, T> {
  LaunchLRN(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
              Tensor* output) {
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

#if defined(IS_MOBILE_PLATFORM)
    SingleThreadedLRN(in, batch, rows, cols, depth, output);
#else
    const int nodes = cols * rows;
    if (depth > kSingleThreadedLRNDepthCutoff &&
        (beta_ == T(0.5) || beta_ == T(1))) {
      SingleThreadedLRN(in, batch, rows, cols, depth, output);
      return;
    }

    auto in_shaped = in.shaped<T, 2>({nodes * batch, depth});

    // Contracting the squared input with the band matrix sums over the
    // correct window along the depth dimension.
    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    GetBandMatrix<T>(depth, depth_radius_, &multiplier);

    auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
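    // tmp(n, d) now holds bias + alpha * sum of in(n, k)^2 over the depth
    // window around d (the band matrix clips the window at the tensor edges).
    // The output is in * tmp^(-beta); the common cases beta == 1 and
    // beta == 0.5 use the cheaper inverse() / rsqrt() instead of the general
    // log/exp formulation.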
    if (beta_ == T(1)) {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * tmp.inverse();
    } else if (beta_ == T(0.5)) {
      out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    } else {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * (tmp.log() * -beta_).exp();
    }
#endif
  }

 private:
  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;

  void SingleThreadedLRN(const Tensor& in, const int batch, const int rows,
                         const int cols, const int depth, Tensor* out) {
    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> data_in(
        in.flat<T>().data(), depth, batch * rows * cols);

    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> data_out(
        out->flat<T>().data(), depth, batch * rows * cols);

    const int double_depth_radius = depth_radius_ * 2;
    Eigen::Matrix<T, Eigen::Dynamic, 1> padded_square(data_in.rows() +
                                                      double_depth_radius);
    padded_square.setZero();
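    // padded_square keeps depth_radius_ zero entries at each end, so the
    // sliding-window accumulation below never needs to test whether the
    // window is clipped at the tensor boundary.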
    for (int r = 0; r < data_in.cols(); ++r) {
      // Do local response normalization for data_in(:, r). First, compute the
      // squares and store them in a buffer for repeated use.
      padded_square.block(depth_radius_, 0, data_out.rows(), 1) =
          data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_;
      // Then, compute the scale and write it to data_out.
      T accumulated_scale(0);
      for (int i = 0; i < double_depth_radius; ++i) {
        accumulated_scale += padded_square(i);
      }
      for (int i = 0; i < data_in.rows(); ++i) {
        accumulated_scale += padded_square(i + double_depth_radius);
        data_out(i, r) = bias_ + accumulated_scale;
        accumulated_scale -= padded_square(i);
      }
    }

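    // data_out now holds the scale (bias + alpha * windowed sum of squares)
    // for every element. Finish by raising it to -beta and multiplying by the
    // input, special-casing the common beta values 1 and 0.5.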
    if (beta_ == T(1)) {
      data_out.array() = data_in.array() * data_out.array().inverse();
    } else if (beta_ == T(0.5)) {
      data_out.array() = data_in.array() * data_out.array().rsqrt();
    } else {
      data_out.array() =
          data_in.array() * (data_out.array().log() * -beta_).exp();
    }
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#if GOOGLE_CUDA

template <typename T>
struct LaunchLRN<GPUDevice, T> {
  LaunchLRN(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
              Tensor* output) {
    OP_REQUIRES(
        context, beta_ >= 0.01,
        errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));

    OP_REQUIRES(
        context, depth_radius_ > 0 && depth_radius_ <= 7,
        errors::InvalidArgument("cuDNN requires depth_radius in [1, 7], got: ",
                                depth_radius_));
    OP_REQUIRES(
        context, bias_ >= 1e-5,
        errors::InvalidArgument("cuDNN requires bias >= 1e-5, got: ", bias_));

    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    se::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(se::dnn::DataLayout::kBatchYXDepth);

    se::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);
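    // The descriptors forward the op attributes and NHWC layout to
    // StreamExecutor's cuDNN-backed normalization. Note that set_range() takes
    // the one-sided depth_radius; the cuDNN layer is expected to widen it to
    // the full 2 * depth_radius + 1 window so results match the CPU path.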

    auto input_data = StreamExecutorUtil::AsDeviceMemory<T>(in);
    auto output_data = StreamExecutorUtil::AsDeviceMemory<T>(*output);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
                                          input_data, &output_data)
            .ok();
    OP_REQUIRES(context, status,
                errors::Internal("NormalizeWithDimensions launch failed"));
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class LRNOp : public OpKernel {
 public:
  explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    float tmp;
    OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp));
    bias_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &tmp));
    alpha_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("beta", &tmp));
    beta_ = T(tmp);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in = context->input(0);
    OP_REQUIRES(context, in.dims() == 4,
                errors::InvalidArgument("in must be 4-dimensional"));
    OP_REQUIRES(
        context,
        FastBoundsCheck(in.NumElements(), std::numeric_limits<int>::max()),
        errors::InvalidArgument("argument to LRN too large"));
    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    OP_REQUIRES(context,
                (depth + depth_radius_) <= std::numeric_limits<int>::max(),
                errors::InvalidArgument("depth ", depth, " + depth_radius ",
                                        depth_radius_, " exceeds int max."));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch, rows, cols, depth}), &output));

    LaunchLRN<Device, T> launcher(depth_radius_, bias_, alpha_, beta_);
    launcher.launch(context, this, in, output);
  }

 private:
  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

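// These kernels back the Python-level tf.nn.local_response_normalization
// (tf.nn.lrn), which emits the "LRN" op: float and half are supported on CPU,
// and float on GPU below.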
#define REGISTER_CPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LRNOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

#undef REGISTER_CPU

#if GOOGLE_CUDA

#define REGISTER_GPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#if !defined(IS_MOBILE_PLATFORM)

template <typename Device, typename T>
struct LaunchLRNGrad;

template <typename T>
struct LaunchLRNGrad<CPUDevice, T> {
  LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel,
              const Tensor& in_grads, const Tensor& in_image,
              const Tensor& out_image, Tensor* output) {
    const int64 batch = in_grads.dim_size(0);
    const int64 rows = in_grads.dim_size(1);
    const int64 cols = in_grads.dim_size(2);
    const int64 depth = in_grads.dim_size(3);
    const auto nodes = cols * rows;
    auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
    auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
    auto activations = out_image.shaped<T, 2>({nodes * batch, depth});

    auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    out_shaped.setZero();

    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
                  depth](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        for (int64 j = 0; j < depth; ++j) {
          // Let y be the LRN activations and x be the inputs along the depth
          // dimension. (LRN operates independently along rows, cols, and
          // batch).
          // We have
          // yi = xi / (bias + alpha(sum_j_{i - depth_radius}^{i + depth_radius}
          //      x_j^2))^beta
          //
          // Let N = (bias + alpha(sum_j_{i - depth_radius}^{i + depth_radius}
          //           x_j^2))
          // dy_i/dx_i = (N^beta - xi. beta*N^(beta-1)*2*alpha*xi)/N^(2*beta)
          // dy_i/dx_j = (       - xi. beta*N^(beta-1)*2*alpha*xj)/N^(2*beta)
          //
          // NOTE(keveman) : We can compute N by doing (yi/xi) ^ (1/beta).
          // However, this is numerically unstable for small values of xi. We
          // compute N explicitly here to avoid that.

          int64 depth_begin = std::max<int64>(0, j - depth_radius_);
          int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);

          T norm(0);
          for (int64 k = depth_begin; k < depth_end; ++k) {
            norm += in_shaped(i, k) * in_shaped(i, k);
          }
          norm = alpha_ * norm + bias_;
          DCHECK_GT(norm, T(1e-6));
          for (int64 k = depth_begin; k < depth_end; ++k) {
            T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
                    activations(i, j) / norm;
            if (k == j) {
              dyi += Eigen::numext::pow(norm, -beta_);
            }
            dyi *= grads_shaped(i, j);
            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) += dyi;
          }
        }
      }
    };
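    // Split the batch * rows * cols positions across the CPU worker pool;
    // depth * depth is the estimated cost per position, which Shard uses to
    // pick the partitioning granularity.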
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
          depth * depth, shard);
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#if GOOGLE_CUDA

template <typename T>
struct LaunchLRNGrad<GPUDevice, T> {
  LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel,
              const Tensor& in_grads, const Tensor& in_image,
              const Tensor& out_image, Tensor* output) {
    OP_REQUIRES(
        context, beta_ >= 0.01,
        errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));

    OP_REQUIRES(
        context, depth_radius_ > 0 && depth_radius_ <= 7,
        errors::InvalidArgument("cuDNN requires depth_radius in [1, 7], got: ",
                                depth_radius_));
    OP_REQUIRES(
        context, bias_ >= 1e-5,
        errors::InvalidArgument("cuDNN requires bias >= 1e-5, got: ", bias_));

    const int64 batch = in_grads.dim_size(0);
    const int64 rows = in_grads.dim_size(1);
    const int64 cols = in_grads.dim_size(2);
    const int64 depth = in_grads.dim_size(3);

    se::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(se::dnn::DataLayout::kBatchYXDepth);

    se::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_grads_data = StreamExecutorUtil::AsDeviceMemory<T>(in_grads);
    auto input_image_data = StreamExecutorUtil::AsDeviceMemory<T>(in_image);
    auto output_image_data = StreamExecutorUtil::AsDeviceMemory<T>(out_image);
    auto output_grads_data = StreamExecutorUtil::AsDeviceMemory<T>(*output);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeBackwardWithDimensions(
                normalize_desc, dimensions_desc, input_image_data,
                output_image_data, input_grads_data, &output_grads_data)
            .ok();
    OP_REQUIRES(
        context, status,
        errors::Internal("NormalizeBackwardWithDimensions launch failed"));
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class LRNGradOp : public OpKernel {
 public:
  explicit LRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    float tmp;
    OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp));
    bias_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &tmp));
    alpha_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("beta", &tmp));
    beta_ = T(tmp);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in_grads = context->input(0);
    const Tensor& in_image = context->input(1);
    const Tensor& out_image = context->input(2);

    OP_REQUIRES(context, in_grads.dims() == 4 && in_image.dims() == 4,
                errors::InvalidArgument("inputs must be 4-dimensional"));
    const int64 batch = in_grads.dim_size(0);
    const int64 rows = in_grads.dim_size(1);
    const int64 cols = in_grads.dim_size(2);
    const int64 depth = in_grads.dim_size(3);
    OP_REQUIRES(
        context,
        in_image.dim_size(0) == batch && in_image.dim_size(1) == rows &&
            in_image.dim_size(2) == cols && in_image.dim_size(3) == depth &&
            out_image.dim_size(0) == batch && out_image.dim_size(1) == rows &&
            out_image.dim_size(2) == cols && out_image.dim_size(3) == depth,
        errors::InvalidArgument(
            "input_grads, input_image, and out_image should have the same "
            "shape"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch, rows, cols, depth}), &output));

    LaunchLRNGrad<Device, T> launcher(depth_radius_, bias_, alpha_, beta_);
    launcher.launch(context, this, in_grads, in_image, out_image, output);
  }

 private:
  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("LRNGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LRNGradOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

#undef REGISTER_CPU

#if GOOGLE_CUDA

#define REGISTER_GPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("LRNGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNGradOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#endif  // !defined(IS_MOBILE_PLATFORM)

}  // namespace tensorflow