#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/Normalization.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

#include <ATen/Dispatch.h>

namespace at::native {
namespace {

// Fills `iter`'s output with the per-element scale factor used by renorm:
// for each input norm, returns maxnorm / (norm + eps) when the norm exceeds
// maxnorm, and 1 otherwise. Applying this factor rescales any sub-tensor
// whose norm is over the limit back down to (approximately) maxnorm.
//
// Args:
//   iter     - TensorIterator with one float input (the norms) and one output
//              of the same dtype; iteration layout is whatever the caller
//              configured (this kernel is elementwise via gpu_kernel).
//   maxnorm  - the norm ceiling, passed as double and narrowed to the
//              iterator's common dtype before use on-device.
void renorm_scale_factor_impl(TensorIteratorBase& iter, double maxnorm) {
  // NOTE: the dispatch key names this kernel; it must say "cuda" here so the
  // "not implemented for '<dtype>'" error reports the GPU kernel, not a CPU one.
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "renorm_scale_factor_cuda", [&] {
    const auto maxnorm_s = static_cast<scalar_t>(maxnorm);
    gpu_kernel(
      iter,
      [maxnorm_s] GPU_LAMBDA (scalar_t norm) -> scalar_t {
        // eps guards the division when norm is (near) zero; 1e-7 matches the
        // tolerance used by the corresponding CPU implementation.
        const auto eps = static_cast<scalar_t>(1e-7);
        const auto one = static_cast<scalar_t>(1.0);
        return (norm > maxnorm_s) ?
            maxnorm_s / (norm + eps) : one;
      });
  });
}

}  // namespace (anonymous)

REGISTER_DISPATCH(renorm_scale_factor_stub, &renorm_scale_factor_impl);

}  // namespace at::native