#define TORCH_ASSERT_NO_OPERATORS #define _USE_MATH_DEFINES #include #include #include #include #include #include #include #include #include #include #include namespace at::native { void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { if (approximate == GeluType::Tanh) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { using opmath_t = at::opmath_type; constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); constexpr opmath_t kKappa = 0.044715; auto x_cube = static_cast(x) * static_cast(x) * static_cast(x); auto inner = kBeta * (static_cast(x) + kKappa * x_cube); return opmath_t(0.5) * static_cast(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner)); }); }); } else { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { using opmath_t = at::opmath_type; constexpr opmath_t kAlpha = M_SQRT1_2; return static_cast(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); }); }); } } void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { if (approximate == GeluType::Tanh) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { using opmath_t = at::opmath_type; constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); constexpr opmath_t kKappa = 0.044715; auto x_sq = static_cast(x) * static_cast(x); auto x_cube = x_sq * static_cast(x); auto inner = kBeta * (static_cast(x) + kKappa * x_cube); auto tanh_inner = c10::cuda::compat::tanh(inner); auto left = opmath_t(0.5) * static_cast(x); auto right = opmath_t(1) + tanh_inner; auto left_derivative = opmath_t(0.5) * right; auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner; auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq); auto right_derivative = left * tanh_derivative * inner_derivative; return static_cast(dy) * (left_derivative + right_derivative); }); }); } else { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { using opmath_t = at::opmath_type; constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5); constexpr opmath_t kAlpha = M_SQRT1_2; const opmath_t cdf = opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); const opmath_t pdf = c10::cuda::compat::exp( opmath_t(-0.5) * static_cast(x) * static_cast(x)) * kBeta; return static_cast(dy) * (cdf + static_cast(x) * pdf); }); }); } } } // namespace at::native