// Header names were stripped in the flattened source; the set below is the
// plausible minimum for the symbols used here (ATen tensors, CUDA guard and
// stream, HAS_NCCL_BF16_DATATYPE, the TENSOR_* checks, and op registration).
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>
#include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
#include <torch/library.h>

#include <cstdint>

// TODO: The kernels are copied from fbgemm_gpu; we should dedup them later.

// FP32 -> BF16 kernel
__global__ void _float_to_bfloat16_cuda_kernel(
    const float* __restrict__ input,
    const size_t nrows,
    const size_t ncols,
    uint16_t* __restrict__ output) {
  const auto row_incre = blockDim.y * gridDim.y;
  const auto col_incre = blockDim.x * gridDim.x;
  for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
       row += row_incre) {
    const float* input_row = input + row * ncols;
    uint16_t* output_row = output + row * ncols;
    for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
         col += col_incre) {
      // Reinterpret the FP32 bits as uint32, add 2^15, and right-shift by 16
      // to round-nearest while keeping the upper 16 bits (the BF16 value).
      output_row[col] =
          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
          16;
    }
  }
}

// BF16 -> FP32 kernel
__global__ void _bfloat16_to_float_cuda_kernel(
    const uint16_t* __restrict__ input,
    const size_t nrows,
    const size_t ncols,
    float* __restrict__ output) {
  const auto row_incre = blockDim.y * gridDim.y;
  const auto col_incre = blockDim.x * gridDim.x;
  for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
       row += row_incre) {
    for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
         col += col_incre) {
      const uint16_t* input_row = input + row * ncols;
      float* output_row = output + row * ncols;
      // Place the BF16 bits in the upper half of a zero-filled uint32 and
      // store that bit pattern as the FP32 result.
      uint32_t val_fp32 =
          static_cast<uint32_t>(
              reinterpret_cast<const uint16_t*>(input_row)[col])
          << 16;
      reinterpret_cast<uint32_t*>(output_row)[col] = val_fp32;
    }
  }
}

namespace torch::distributed::c10d::quantization {

at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
  TENSOR_ON_CUDA_GPU(input);
  // Currently only 2D inputs are supported.
  TENSOR_NDIM_EQUALS(input, 2);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  const auto nrows = input.size(0);
  const auto ncols = input.size(1);
  const size_t output_columns = ncols;

  auto output = at::empty(
      {nrows, ncols},
#if HAS_NCCL_BF16_DATATYPE
      input.options().dtype(at::kBFloat16));
#else
      // Without NCCL BF16 support, carry the raw BF16 bits in a Half tensor.
      input.options().dtype(at::kHalf));
#endif

  if (nrows == 0 || ncols == 0) {
    return output;
  }

  constexpr size_t threads_per_block = 256;
  const auto blockDim_x = std::min(output_columns, threads_per_block);
  dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
  const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
  // std::min<size_t> avoids a mixed signed/unsigned deduction failure.
  const auto gridDim_y =
      std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u);
  dim3 gridDim(gridDim_x, gridDim_y);

  _float_to_bfloat16_cuda_kernel<<<
      gridDim,
      blockDim,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      input.const_data_ptr<float>(),
      nrows,
      ncols,
#if HAS_NCCL_BF16_DATATYPE
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::BFloat16>())
#else
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>())
#endif
  );
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}

at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
  TENSOR_ON_CUDA_GPU(input);
  // Currently only 2D inputs are supported.
  TENSOR_NDIM_EQUALS(input, 2);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  const auto nrows = input.size(0);
  const auto ncols = input.size(1);
  const size_t output_columns = ncols;

  auto output = at::empty(
      {nrows, ncols},
      input.options().dtype(at::kFloat)); // FP32 output, 4 bytes per element

  if (nrows == 0 || ncols == 0) {
    return output;
  }

  constexpr size_t threads_per_block = 256;
  const auto blockDim_x = std::min(output_columns, threads_per_block);
  dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
  const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
  const auto gridDim_y =
      std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u);
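  // Launch-config note: blockDim is a 256-thread tile shaped to the row
  // width, and gridDim.y is clamped to 65535, the CUDA limit on a grid's
  // y-dimension; rows beyond the grid are covered by the kernel's stride loop.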
  dim3 gridDim(gridDim_x, gridDim_y);

  _bfloat16_to_float_cuda_kernel<<<
      gridDim,
      blockDim,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
#if HAS_NCCL_BF16_DATATYPE
      reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::BFloat16>()),
#else
      reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::Half>()),
#endif
      nrows,
      ncols,
      output.mutable_data_ptr<float>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}

#define DISPATCH_TO_CUDA(name, function) \
  m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))

TORCH_LIBRARY_IMPL(quantization, CUDA, m) {
  DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda);
  DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda);
}

} // namespace torch::distributed::c10d::quantization
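// Usage sketch (illustrative only, not part of the original file; the tensor
// names are hypothetical): the two ops above round-trip a 2D FP32 CUDA tensor
// through BF16, which is what the quantized all-to-all paths rely on.
//
//   at::Tensor t = at::rand({8, 16}, at::device(at::kCUDA).dtype(at::kFloat));
//   at::Tensor q =
//       torch::distributed::c10d::quantization::_float_to_bfloat16_cuda(t);
//   at::Tensor r =
//       torch::distributed::c10d::quantization::_bfloat16_to_float_cuda(q);
//   // r approximates t with BF16 precision (~3 decimal digits of mantissa).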