// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// multi_thread_gemm.h: Entry point to the multithreaded version of the
// generated (meta) gemm library.

#ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_META_MULTI_THREAD_GEMM_H_

#ifdef GEMMLOWP_NEON_32

#include <algorithm>
#include <cstdint>

#include "multi_thread_common.h"
#include "single_thread_gemm.h"

namespace gemmlowp {
namespace meta {
namespace internal {

// Upper bound, in bytes, on the rhs slice that CacheFriendlyMatrixMatrix
// aims to hand to a single ExecuteCacheFriendlyMatrixMatrix call.
const std::int32_t kMaxCacheFriendlySize = 24 * 1024;

// Runs `operation` over lhs x rhs, splitting the rhs into chunks of columns
// so that each operation.ExecuteCacheFriendlyMatrixMatrix call touches an rhs
// slice of roughly kMaxCacheFriendlySize bytes.
//
// Layout implied by the pointer arithmetic below:
//   - each of the `n` rhs vectors occupies `k` contiguous IN_TYPE elements;
//   - result columns are adjacent in memory, result rows are `result_stride`
//     elements apart.
// The lhs pointer is forwarded unchanged to every chunk; its layout is
// whatever `operation` expects (presumably m rows of k elements).
//
// If the whole rhs already fits in kMaxCacheFriendlySize bytes, a single call
// is made. Otherwise chunks are `optimal_n` columns wide, where optimal_n is
// the largest multiple of 3 whose rhs slice fits the budget (but at least 1).
// The final chunk absorbs the remainder, so it spans between optimal_n and
// 2 * optimal_n - 1 columns.
//
// `scratch` is forwarded unchanged to every chunk call; calls are sequential.
template <typename IN_TYPE, typename OUT_TYPE, typename F>
void CacheFriendlyMatrixMatrix(std::uint8_t* scratch, const IN_TYPE* lhs,
                               const IN_TYPE* rhs, std::int32_t m,
                               std::int32_t n, std::int32_t k,
                               OUT_TYPE* result, std::int32_t result_stride,
                               const F& operation) {
  // Bytes occupied by one rhs vector of length k.
  const std::int32_t rhs_vector_bytes =
      k * static_cast<std::int32_t>(sizeof(IN_TYPE));
  // Computed in 64 bits so the size check itself cannot overflow.
  const std::int64_t rhs_size =
      static_cast<std::int64_t>(n) * rhs_vector_bytes;

  if (rhs_size <= kMaxCacheFriendlySize) {
    operation.ExecuteCacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k,
                                               result, result_stride);
    return;
  }

  // Widest multiple-of-3 column count whose rhs slice fits the byte budget;
  // falls back to 1 when even 3 columns would exceed it (very large k).
  const std::int32_t optimal_n = std::max<std::int32_t>(
      1, 3 * (kMaxCacheFriendlySize / (rhs_vector_bytes * 3)));
  // Because the whole rhs exceeds the budget, n >= optimal_n here, so this
  // count is never negative.
  const std::int32_t chunks_count_less_one = n / optimal_n - 1;
  // Number of rhs elements spanned by one full chunk.
  const std::int32_t chunk_size = optimal_n * k;

  for (std::int32_t i = 0; i < chunks_count_less_one; ++i) {
    operation.ExecuteCacheFriendlyMatrixMatrix(
        scratch, lhs, rhs + i * chunk_size, m, optimal_n, k,
        result + i * optimal_n, result_stride);
  }

  // The last chunk covers all remaining columns (optimal_n .. 2*optimal_n-1).
  const std::int32_t n_left = n - chunks_count_less_one * optimal_n;
  operation.ExecuteCacheFriendlyMatrixMatrix(
      scratch, lhs, rhs + chunks_count_less_one * chunk_size, m, n_left, k,
      result + chunks_count_less_one * optimal_n, result_stride);
}

// Operation object for the uint8 -> uint8 (quantized) gemm. It stores the
// quantization parameters and forwards them to gemm_q8_strided, which is
// declared elsewhere (presumably single_thread_gemm.h).
class GemmQuantized8BitOperation {
 public:
  GemmQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
                             std::int32_t sum_offset, std::int32_t multiplier,
                             std::int32_t shift)
      : lhs_offset(lhs_offset),
        rhs_offset(rhs_offset),
        sum_offset(sum_offset),
        multiplier(multiplier),
        shift(shift) {}

  // Computes one (possibly large) block by splitting it into cache-sized
  // rhs chunks via CacheFriendlyMatrixMatrix.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::uint8_t* result,
                           std::int32_t result_stride) const {
    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result,
                              result_stride, *this);
  }

  // Computes one cache-sized chunk with the single-threaded strided kernel.
  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
                                        const std::uint8_t* lhs,
                                        const std::uint8_t* rhs,
                                        std::int32_t m, std::int32_t n,
                                        std::int32_t k, std::uint8_t* result,
                                        std::int32_t result_stride) const {
    gemm_q8_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    sum_offset, multiplier, shift, result, result_stride);
  }

  // Scratch bytes needed per worker thread: a fixed 128 KiB, independent of
  // the matrix dimensions.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }

 private:
  std::int32_t lhs_offset;
  std::int32_t rhs_offset;
  std::int32_t sum_offset;
  std::int32_t multiplier;
  std::int32_t shift;
};

// Operation object for the uint8 -> float gemm. It forwards its offsets to
// gemm_f_strided, which is declared elsewhere (presumably single_thread_gemm.h).
class GemmFloatOperation {
 public:
  GemmFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
                     float result_offset)
      : lhs_offset(lhs_offset),
        rhs_offset(rhs_offset),
        result_offset(result_offset) {}

  // Computes one (possibly large) block by splitting it into cache-sized
  // rhs chunks via CacheFriendlyMatrixMatrix.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, float* result,
                           std::int32_t result_stride) const {
    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result,
                              result_stride, *this);
  }

  // Computes one cache-sized chunk with the single-threaded strided kernel.
  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
                                        const std::uint8_t* lhs,
                                        const std::uint8_t* rhs,
                                        std::int32_t m, std::int32_t n,
                                        std::int32_t k, float* result,
                                        std::int32_t result_stride) const {
    gemm_f_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                   result_offset, result, result_stride);
  }

  // Scratch bytes needed per worker thread: a fixed 128 KiB, independent of
  // the matrix dimensions.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }

 private:
  std::int32_t lhs_offset;
  std::int32_t rhs_offset;
  float result_offset;
};

// Operation object for the uint8 -> int32 gemm. It forwards its offsets to
// gemm_i32_strided, which is declared elsewhere (presumably
// single_thread_gemm.h).
class GemmInt32Operation {
 public:
  GemmInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
      : lhs_offset(lhs_offset), rhs_offset(rhs_offset) {}

  // Computes one (possibly large) block by splitting it into cache-sized
  // rhs chunks via CacheFriendlyMatrixMatrix.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t* result,
                           std::int32_t result_stride) const {
    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result,
                              result_stride, *this);
  }

  // Computes one cache-sized chunk with the single-threaded strided kernel.
  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
                                        const std::uint8_t* lhs,
                                        const std::uint8_t* rhs,
                                        std::int32_t m, std::int32_t n,
                                        std::int32_t k, std::int32_t* result,
                                        std::int32_t result_stride) const {
    gemm_i32_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                     result, result_stride);
  }

  // Scratch bytes needed per worker thread: a fixed 128 KiB, independent of
  // the matrix dimensions.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }

 private:
  std::int32_t lhs_offset;
  std::int32_t rhs_offset;
};

}  // namespace internal

// The public functions below are defined in a header, so they are `inline`
// to avoid multiple-definition errors when the header is included from more
// than one translation unit.

// Returns the scratch size in bytes for multi_thread_gemm_q8: the thread
// count resolved by internal::ResolveMaxThreads(max_threads) times the fixed
// per-thread requirement.
inline std::int32_t gemm_q8_scratch(std::int32_t m, std::int32_t n,
                                    std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemmQuantized8BitOperation::ScratchPerThread(m, n, k);
}

// Multithreaded uint8 -> uint8 quantized gemm. The result is written with a
// row stride of `n` elements. The work split across `pool` is done by
// internal::MultiThreadedMatrixMatrix, which is defined elsewhere.
// NOTE(review): `scratch` is presumably sized with
// gemm_q8_scratch(m, n, k, max_threads); confirm against the callers.
inline void multi_thread_gemm_q8(
    gemmlowp::WorkersPool* pool, std::int32_t max_threads,
    std::uint8_t* scratch, const std::uint8_t* lhs, const std::uint8_t* rhs,
    std::int32_t m, std::int32_t n, std::int32_t k, std::int32_t lhs_offset,
    std::int32_t rhs_offset, std::int32_t sum_offset, std::int32_t multiplier,
    std::int32_t shift, std::uint8_t* result) {
  internal::GemmQuantized8BitOperation operation(lhs_offset, rhs_offset,
                                                 sum_offset, multiplier, shift);
  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
                                      n, k, result, n, operation);
}

// Returns the scratch size in bytes for multi_thread_gemm_f: the thread
// count resolved by internal::ResolveMaxThreads(max_threads) times the fixed
// per-thread requirement.
inline std::int32_t gemm_f_scratch(std::int32_t m, std::int32_t n,
                                   std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemmFloatOperation::ScratchPerThread(m, n, k);
}

// Multithreaded uint8 -> float gemm. The result is written with a row stride
// of `n` elements. The work split across `pool` is done by
// internal::MultiThreadedMatrixMatrix, which is defined elsewhere.
// NOTE(review): `scratch` is presumably sized with
// gemm_f_scratch(m, n, k, max_threads); confirm against the callers.
inline void multi_thread_gemm_f(gemmlowp::WorkersPool* pool,
                                std::int32_t max_threads,
                                std::uint8_t* scratch, const std::uint8_t* lhs,
                                const std::uint8_t* rhs, std::int32_t m,
                                std::int32_t n, std::int32_t k,
                                std::int32_t lhs_offset,
                                std::int32_t rhs_offset, float result_offset,
                                float* result) {
  internal::GemmFloatOperation operation(lhs_offset, rhs_offset,
                                         result_offset);
  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
                                      n, k, result, n, operation);
}

// Returns the scratch size in bytes for multi_thread_gemm_i32: the thread
// count resolved by internal::ResolveMaxThreads(max_threads) times the fixed
// per-thread requirement.
inline std::int32_t gemm_i32_scratch(std::int32_t m, std::int32_t n,
                                     std::int32_t k,
                                     std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemmInt32Operation::ScratchPerThread(m, n, k);
}

// Multithreaded uint8 -> int32 gemm. The result is written with a row stride
// of `n` elements. The work split across `pool` is done by
// internal::MultiThreadedMatrixMatrix, which is defined elsewhere.
// NOTE(review): `scratch` is presumably sized with
// gemm_i32_scratch(m, n, k, max_threads); confirm against the callers.
inline void multi_thread_gemm_i32(gemmlowp::WorkersPool* pool,
                                  std::int32_t max_threads,
                                  std::uint8_t* scratch,
                                  const std::uint8_t* lhs,
                                  const std::uint8_t* rhs, std::int32_t m,
                                  std::int32_t n, std::int32_t k,
                                  std::int32_t lhs_offset,
                                  std::int32_t rhs_offset,
                                  std::int32_t* result) {
  internal::GemmInt32Operation operation(lhs_offset, rhs_offset);
  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
                                      n, k, result, n, operation);
}

}  // namespace meta
}  // namespace gemmlowp

#else
#warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
#endif

#endif  // GEMMLOWP_META_MULTI_THREAD_GEMM_H_