#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/OpMathType.h>
#include <c10/util/MathConstants.h>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native {

// Computes log(exp(a) + exp(b)) without overflow: with m = max(a, b),
// log(exp(a) + exp(b)) = m + log1p(exp(-|a - b|)).
void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.dtype(), "logaddexp_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
          const auto a = static_cast<opmath_t>(a_);
          const auto b = static_cast<opmath_t>(b_);
          if (::isinf(a) && a == b) {
            // a == b == +/-infinity: a - b below would be NaN.
            return a;
          } else {
            const auto m = ::max(a, b);
            return m + ::log1p(::exp(-::abs(a - b)));
          }
        });
      });
}

// Computes log2(2^a + 2^b) analogously: with m = max(a, b),
// log2(2^a + 2^b) = m + log1p(exp2(-|a - b|)) / ln(2).
void logaddexp2_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.dtype(), "logaddexp2_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        const auto inv_log_2 = static_cast<opmath_t>(1.0 / c10::ln_2<double>);
        gpu_kernel(iter, [inv_log_2] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
          const auto a = static_cast<opmath_t>(a_);
          const auto b = static_cast<opmath_t>(b_);
          if (::isinf(a) && a == b) {
            // a == b == +/-infinity: a - b below would be NaN.
            return a;
          } else {
            const auto m = ::max(a, b);
            return m + ::log1p(::exp2(-::abs(a - b))) * inv_log_2;
          }
        });
      });
}

REGISTER_DISPATCH(logaddexp_stub, &logaddexp_kernel_cuda);
REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda);

} // namespace at::native
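
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the kernel file): these stubs
// back at::logaddexp / at::logaddexp2 on CUDA tensors through the ATen C++
// frontend. Assuming a separate translation unit that may include
// <ATen/ATen.h> (this file itself cannot, due to TORCH_ASSERT_NO_OPERATORS):
//
//   auto opts = at::device(at::kCUDA).dtype(at::kFloat);
//   auto a = at::full({1}, -1000.0, opts);
//   auto b = at::full({1}, -1000.0, opts);
//   // A naive log(exp(a) + exp(b)) underflows to log(0) = -inf here;
//   // the stable kernel returns -1000 + ln(2) ~= -999.3069.
//   auto r = at::logaddexp(a, b);
// ---------------------------------------------------------------------------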