• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/Normalization.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

#include <ATen/Dispatch.h>

namespace at::native {
namespace {

// CUDA backend for the renorm scale-factor step: for every norm value
// produced by the preceding reduction, writes the multiplicative factor
// that clamps the corresponding slice back inside `maxnorm`:
//   norm >  maxnorm  ->  maxnorm / (norm + eps)   (shrink the slice)
//   norm <= maxnorm  ->  1.0                      (leave it unchanged)
//
// `iter` is a TensorIterator with one input (the norms) and one output
// (the factors); dispatch covers the float/double dtypes.
void renorm_scale_factor_impl(TensorIteratorBase& iter, double maxnorm) {
  // NOTE(review): the dispatch-name string says "cpu" even though this is
  // the CUDA kernel; it only appears in dtype-dispatch error messages, so
  // it is kept byte-identical here — consider renaming upstream.
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "renorm_scale_factor_cpu", [&] {
    const auto maxnorm_s = static_cast<scalar_t>(maxnorm);
    gpu_kernel(
      iter,
      [maxnorm_s] GPU_LAMBDA (scalar_t norm) -> scalar_t {
        // eps keeps the division well-behaved when norm barely exceeds a
        // tiny maxnorm; presumably mirrors the CPU implementation's eps —
        // TODO(review): confirm against the CPU renorm kernel.
        const auto eps = static_cast<scalar_t>(1e-7);
        const auto one = static_cast<scalar_t>(1.0);
        return (norm > maxnorm_s) ?
            maxnorm_s / (norm + eps) : one;
      });
  });
}

}  // namespace (anonymous)

// Route the renorm scale-factor stub to this CUDA implementation.
REGISTER_DISPATCH(renorm_scale_factor_stub, &renorm_scale_factor_impl);

}  // namespace at::native
30