• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 #include <ATen/core/TensorBase.h>
3 #include <optional>
4 
5 
6 namespace at::cuda::detail {
// Declaration only: the definition of f8f8bf16_rowwise is not in this header.
//
// NOTE(review): judging by the name and the per-argument dtype comments, this
// looks like a row-wise-scaled FP8 x FP8 matrix multiply that produces a BF16
// result in `out` -- confirm against the definition.
//
// Arguments:
//   XQ, WQ           - FP8 input tensors. Shapes and layouts are not stated
//                      here -- TODO confirm against the definition.
//   x_scale, w_scale - FP32 scale tensors; presumably the per-row scales for
//                      XQ and WQ respectively -- TODO confirm.
//   bias             - optional BF16 tensor; may be std::nullopt.
//   use_fast_accum   - boolean flag; its effect is not visible here
//                      (presumably selects a faster accumulation mode --
//                      TODO confirm).
//   out              - taken by non-const reference; presumably the
//                      destination tensor, since the function returns void.
7 TORCH_API void f8f8bf16_rowwise(
8     at::Tensor XQ, // FP8 input
9     at::Tensor WQ, // FP8 input
10     at::Tensor x_scale, // FP32 scale paired with XQ
11     at::Tensor w_scale, // FP32 scale paired with WQ
12     std::optional<at::Tensor> bias, // BF16, may be absent
13     bool use_fast_accum,
14     at::Tensor& out);
15 }  // namespace at::cuda::detail
16