• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #define TORCH_ASSERT_NO_OPERATORS
2 #define _USE_MATH_DEFINES
3 
4 #include <ATen/native/Activation.h>
5 
6 #include <cmath>
7 
8 #include <thrust/tuple.h>
9 
10 #include <ATen/AccumulateType.h>
11 #include <ATen/Dispatch.h>
12 #include <ATen/core/TensorBase.h>
13 #include <c10/core/Scalar.h>
14 #include <c10/cuda/CUDAMathCompat.h>
15 #include <ATen/cuda/ApplyGridUtils.cuh>
16 #include <ATen/cuda/detail/OffsetCalculator.cuh>
17 #include <ATen/native/cuda/Loops.cuh>
18 
19 namespace at::native {
20 namespace {
21 
// Elementwise Mish forward pass on the GPU.
//
// For every element x visited by `iter`, produces
//     x * tanh(log1p(exp(x)))
// The dispatch covers the floating types plus Half and BFloat16. The
// arithmetic is performed in at::opmath_type<scalar_t>, and the result is
// converted back to scalar_t by the lambda's declared return type.
void mish_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "mish_cuda",
      [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          namespace compat = c10::cuda::compat;
          using opmath_t = at::opmath_type<scalar_t>;

          const opmath_t input = static_cast<opmath_t>(x);
          // softplus(x) = log(1 + exp(x))
          const opmath_t softplus = compat::log1p(compat::exp(input));
          return input * compat::tanh(softplus);
        });
      });
}
38 
// Elementwise Mish backward pass on the GPU.
//
// The lambda receives (dy, x) for every element visited by `iter` and
// produces
//     dy * (t + x * s * (1 - t * t))
// where
//     s = 1 / (1 + exp(-x))
//     t = tanh(log1p(exp(x)))
// The dispatch covers the floating types plus Half and BFloat16. The
// arithmetic is performed in at::opmath_type<scalar_t>, and the result is
// converted back to scalar_t by the lambda's declared return type.
void mish_backward_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "mish_backward_cuda",
      [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
          namespace compat = c10::cuda::compat;
          using opmath_t = at::opmath_type<scalar_t>;

          const opmath_t grad_out = static_cast<opmath_t>(dy);
          const opmath_t input = static_cast<opmath_t>(x);

          // s = 1 / (1 + exp(-x))
          const opmath_t exp_neg_input = compat::exp(-input);
          const opmath_t sigmoid = opmath_t(1) / (opmath_t(1) + exp_neg_input);

          // t = tanh(log(1 + exp(x)))
          const opmath_t softplus = compat::log1p(compat::exp(input));
          const opmath_t tanh_sp = compat::tanh(softplus);

          // Kept as a single expression so the evaluation order matches the
          // formula in the header comment term for term.
          return grad_out *
              (tanh_sp + input * sigmoid * (opmath_t(1) - tanh_sp * tanh_sp));
        });
      });
}
59 } // namespace
60 
61 REGISTER_DISPATCH(mish_stub, &mish_kernel);
62 REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel);
63 
64 } // namespace at::native
65