#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/zmath.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>

#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>

namespace at::native::blas_impl {
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

float fp16_dot_with_fp32_arith(
    const float16_t* x,
    const float16_t* a,
    int64_t len);

float bf16_dot_with_fp32_arith(
    const at::BFloat16* x,
    const at::BFloat16* a,
    int64_t len);
} // namespace at::native::blas_impl
#endif

namespace at::native {
namespace cpublas {
namespace {

// Scale the m x n column-major matrix a (leading dimension lda) by alpha.
template <typename scalar_t, typename opmath_t>
void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t *a, int64_t lda) {
  if (alpha == opmath_t(1)) {
    return;  // identity
  }

  if (alpha == opmath_t(0)) {
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        a[j * lda + i] = scalar_t(0);
      }
    }
    return;
  }

  for (const auto j : c10::irange(n)) {
    for (const auto i : c10::irange(m)) {
      a[j * lda + i] *= alpha;
    }
  }
}

template <typename Func>
auto sum(int64_t N, Func f) {
  constexpr int ilp_factor = 4;
  using acc_t = decltype(f(0));

  // Calculate independent partial sums then add together at the end
  std::array<acc_t, ilp_factor> partial_sums{};

  int64_t i = 0;
  for (; i + ilp_factor <= N; i += ilp_factor) {
    c10::ForcedUnroll<ilp_factor>{}([&](int k) {
      partial_sums[k] += f(i + k);
    });
  }
  for (; i < N; ++i) {
    partial_sums[0] += f(i);
  }
  for (int k = 1; k < ilp_factor; ++k) {
    partial_sums[0] += partial_sums[k];
  }
  return partial_sums[0];
}
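// Illustrative sketch only (this helper is hypothetical, not part of the
// kernel surface): the GEMM kernels below call sum() with a lambda mapping
// an index to one product term. The four independent accumulators inside
// sum() break the serial dependency chain on a single register, hiding FMA
// latency relative to a plain `for (l) dot += x[l] * y[l];` loop.
[[maybe_unused]] static float example_dot_via_sum(
    const float* x, const float* y, int64_t len) {
  return sum(len, [&](int64_t l) -> float {
    return x[l] * y[l];  // one term of the dot product
  });
}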
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b)
  for (const auto l : c10::irange(k)) {
    for (const auto j : c10::irange(n)) {
      opmath_t val = b[l + j * ldb] * alpha;
      int64_t i_m = m / 4;
      // Unrolled by 4 over the rows of c
      for (const auto i_i : c10::irange(i_m)) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      int64_t i = i_m * 4;
      for (; i < m; i++)
        c[j * ldc + i] += a[i + l * lda] * val;
    }
  }
}

// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c += alpha * (a @ b)
  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[j * ldb + l]);
      });
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}

template <typename scalar_t, typename opmath_t>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  const scalar_t *a_ = a;
  for (const auto i : c10::irange(m)) {
    const scalar_t *b_ = b;
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(
                   transa == TransposeType::ConjTranspose ? conj_impl(a_[l])
                                                          : a_[l]) *
            static_cast<opmath_t>(b_[l]);
      });
      b_ += ldb;
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
    a_ += lda;
  }
}

template <typename scalar_t, typename opmath_t>
void gemm_transb_impl(
    TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t* a, int64_t lda,
    const scalar_t* b, int64_t ldb,
    /* we expect pre-applied beta */
    opmath_t* c, int64_t ldc) {
  // c += alpha * (a @ b.T)
  for (const auto l : c10::irange(k)) {
    for (const auto j : c10::irange(n)) {
      opmath_t val =
          (transb == TransposeType::ConjTranspose ? conj_impl(b[j + l * ldb])
                                                  : b[j + l * ldb]) *
          alpha;
      int64_t i_m = m / 4;
      for (const auto i_i : c10::irange(i_m)) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      int64_t i = i_m * 4;
      for (; i < m; i++)
        c[j * ldc + i] += a[i + l * lda] * val;
    }
  }
}

template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t* a, int64_t lda,
    const scalar_t* b, int64_t ldb,
    opmath_t beta,
    scalar_t* c, int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);
  gemm_transb_impl(transb, m, n, k, alpha, a, lda, b, ldb, c, ldc);
}

// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t* a, int64_t lda,
    const scalar_t* b, int64_t ldb,
    opmath_t beta,
    scalar_t* c, int64_t ldc) {
  // We need to calculate full-precision dot products for correctness;
  // users notice error accumulation with reduced-width types (e.g.,
  // https://github.com/pytorch/pytorch/issues/95125 and
  // https://github.com/pytorch/pytorch/issues/83863, which were filed
  // when we used gemm_transb_impl naively, accumulating into
  // float16/bfloat16). The straightforward way to do this is to use
  // the vector dot column algorithm anyway, but this gives terrible
  // performance because of the non-contiguous matrix
  // access. Therefore, we instead elect to allocate temporary space
  // to hold the output at higher precision so that we can accumulate
  // into it using the above cache-friendly "load one vector element,
  // FMA it with an entire matrix row into the entire result vector"
  // algorithm instead.
  const auto c_size = m * n;
  auto c_accum = std::make_unique<opmath_t[]>(c_size);
  // Pre-apply beta into the opmath_t scratch buffer, since
  // gemm_transb_impl expects beta to have been handled already.
  if (beta == 1) {
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        c_accum[j * m + i] = c[j * ldc + i];
      }
    }
  } else if (beta == 0) {
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        c_accum[j * m + i] = 0;
      }
    }
  } else {
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        c_accum[j * m + i] = beta * c[j * ldc + i];
      }
    }
  }
  gemm_transb_impl(transb, m, n, k, alpha, a, lda, b, ldb, c_accum.get(), m);
  // Round the accumulated results back down to scalar_t exactly once.
  for (const auto j : c10::irange(n)) {
    for (const auto i : c10::irange(m)) {
      c[j * ldc + i] = c_accum[j * m + i];
    }
  }
}
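// Illustrative sketch only of the failure mode the comment above describes
// (hypothetical values, not part of the kernel surface): with a
// reduced-precision accumulator the running sum stalls once each addend
// drops below half an ulp of the total, while a float accumulator keeps the
// full sum and rounds just once at the end.
[[maybe_unused]] static void example_reduced_precision_accumulation() {
  at::Half half_acc(0.f);
  float float_acc = 0.f;
  for (int i = 0; i < 4096; ++i) {
    // The arithmetic promotes to float; assigning back to at::Half rounds
    // to fp16 after every step, as naive fp16 accumulation would.
    half_acc = half_acc + at::Half(0.01f);
    float_acc += 0.01f;
  }
  // float_acc is ~40.96, but half_acc stalls near 32, where the fp16 ulp
  // (2^-5 = 0.03125) already exceeds twice the addend.
}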
template <typename scalar_t, typename opmath_t>
void gemm_transab_(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = beta * c + alpha * (a.T @ b.T)
  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(
                   transa == TransposeType::ConjTranspose
                       ? conj_impl(a[i * lda + l])
                       : a[i * lda + l]) *
            static_cast<opmath_t>(
                   transb == TransposeType::ConjTranspose
                       ? conj_impl(b[l * ldb + j])
                       : b[l * ldb + j]);
      });
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}

#if defined(__aarch64__) && !defined(C10_MOBILE)
template <>
void gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* a,
    int64_t lda,
    const at::Half* b,
    int64_t ldb,
    float beta,
    at::Half* c,
    int64_t ldc) {
  // c += alpha * (a @ b)
  if (n == 1 && beta == 0.0 && alpha == 1.0) {
    // Single-column case with no scaling: hand off to the NEON GEMV kernel.
    at::native::blas_impl::fp16_gemv_notrans(
        m,
        k,
        1.0,
        reinterpret_cast<const float16_t*>(a),
        lda,
        reinterpret_cast<const float16_t*>(b),
        1,
        0.0,
        reinterpret_cast<float16_t*>(c),
        1);
    return;
  }
  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> float {
        return float(c10::detail::fp16_from_bits(a[l * lda + i].x)) *
            float(c10::detail::fp16_from_bits(b[j * ldb + l].x));
      });
      if (beta == 0) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}

inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
  // Widen four bf16 values to fp32 by shifting the raw bits left 16 places.
  int32x4_t shift = vdupq_n_s32(16);
  uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t*>(ptr)));
  return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}

static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) {
  return at::native::blas_impl::fp16_dot_with_fp32_arith(
      reinterpret_cast<const float16_t*>(a),
      reinterpret_cast<const float16_t*>(b),
      len);
}

static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t len) {
  return at::native::blas_impl::bf16_dot_with_fp32_arith(a, b, len);
}

template <>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    float beta,
    at::Half *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  if (n == 1 && beta == 0.0 && alpha == 1.0) {
    at::native::blas_impl::fp16_gemv_trans(
        k,
        m,
        1.0,
        reinterpret_cast<const float16_t*>(a),
        lda,
        reinterpret_cast<const float16_t*>(b),
        1,
        0.0,
        reinterpret_cast<float16_t*>(c),
        1);
    return;
  }
  parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
    const auto *a_ = a + begin * lda;
    for (const auto i : c10::irange(begin, end)) {
      const auto *b_ = b;
      for (const auto j : c10::irange(n)) {
        const auto dot = compute_dot(a_, b_, k);
        b_ += ldb;
        if (beta == 0) {
          c[j * ldc + i] = alpha * dot;
        } else {
          c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
        }
      }
      a_ += lda;
    }
  });
}

template <>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    float beta,
    at::BFloat16 *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
    const auto *a_ = a + begin * lda;
    for (const auto i : c10::irange(begin, end)) {
      const auto *b_ = b;
      for (const auto j : c10::irange(n)) {
        const auto dot = compute_dot(a_, b_, k);
        b_ += ldb;
        if (beta == 0) {
          c[j * ldc + i] = alpha * dot;
        } else {
          c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
        }
      }
      a_ += lda;
    }
  });
}
#endif

template <typename scalar_t, typename opmath_t>
void gemm_core_(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  if (transa == TransposeType::NoTranspose &&
      transb == TransposeType::NoTranspose) {
    return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else if (
      transa != TransposeType::NoTranspose &&
      transb == TransposeType::NoTranspose) {
    gemm_transa_(transa, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else if (
      transa == TransposeType::NoTranspose &&
      transb != TransposeType::NoTranspose) {
    gemm_transb_(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else {
    gemm_transab_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}
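// Illustrative sketch only (hypothetical shapes and buffers): a direct call
// into gemm_core_ computing c = alpha * a^T @ b + beta * c with column-major
// operands. On the transposed-A path, a is stored k-major, so lda >= k.
[[maybe_unused]] static void example_gemm_core_call() {
  float a[6] = {};   // 2x3 storage; op(a) = a^T is 3x2, lda = k = 2
  float b[8] = {};   // 2x4 storage; ldb = k = 2
  float c[12] = {};  // 3x4 result; ldc = m = 3
  gemm_core_<float, float>(
      TransposeType::Transpose, TransposeType::NoTranspose,
      /*m=*/3, /*n=*/4, /*k=*/2, /*alpha=*/1.f,
      a, /*lda=*/2, b, /*ldb=*/2, /*beta=*/0.f, c, /*ldc=*/3);
}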
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)       \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(              \
      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn,  \
      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz,              \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)       \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(              \
      kHalf, kBFloat16,                                \
      TYPE, NAME, __VA_ARGS__)
#endif

void cpublas_gemm_impl(
    at::ScalarType type,
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const Scalar& alpha,
    const void *a, int64_t lda,
    const void *b, int64_t ldb,
    const Scalar& beta,
    void *c, int64_t ldc) {
  _AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
        using opmath_t = at::opmath_type<scalar_t>;
        gemm_core_(
            transa, transb, m, n, k,
            alpha.to<opmath_t>(),
            static_cast<const scalar_t *>(a), lda,
            static_cast<const scalar_t *>(b), ldb,
            beta.to<opmath_t>(),
            static_cast<scalar_t *>(c), ldc);
      });
}

void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
  if (type == at::kBool) {
    // Boolean "axpy" degenerates to an OR-accumulate: y |= a & x.
    auto a = _a.to<bool>();
    auto x = static_cast<const bool *>(_x);
    auto y = static_cast<bool *>(_y);
    int64_t i;
    for (i = 0; i < n; i++)
      y[i*incy] |= a & x[i*incx];
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16, type, "cpublas_axpy_impl",
      [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        auto a = _a.to<opmath_t>();
        auto x = static_cast<const scalar_t *>(_x);
        auto y = static_cast<scalar_t *>(_y);
        int64_t i;
        for (i = 0; i < n; i++)
          y[i*incy] += a*x[i*incx];
      });
  }
}

void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t incx, void *_y, int64_t incy){
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::kComplexHalf, at::kHalf, at::kBFloat16, at::kBool, type, "cpublas_copy_impl",
    [&] {
      auto x = static_cast<const scalar_t *>(_x);
      auto y = static_cast<scalar_t *>(_y);
      int64_t i;
      for (i = 0; i < n; i++)
        y[i*incy] = x[i*incx];
    });
}

}}  // namespace cpublas::(anonymous)

REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl);
REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl);
REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl);

} // namespace at::native
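// Usage note (an assumption about the surrounding wiring, not asserted by
// this file): callers reach these kernels through the dispatch stubs rather
// than directly, roughly
//
//   cpublas::gemm_stub(
//       at::kCPU, at::kHalf, transa, transb,
//       m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
//
// with the wrapper in ATen/native/CPUBlas.h preferring an external BLAS
// where available and falling back to gemm_stub, which lands in
// cpublas_gemm_impl registered above.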