// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
#define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <typeinfo>

#include "base.h"

namespace gemmlowp {
namespace meta {

template <typename Executor, typename Params, int kernel_m, int kernel_n,
          int kernel_k>
void Gemm(const Params& params);

class GemmExecutorPackRHS {
 public:
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    const int lhs_scratch =
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n);
    const int rhs_scratch =
        rhs_chunks *
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    typedef Stream<InType, m, k, k_leftovers, typename P::LeftStream>
        LeftStreamF;
    typedef Stream<InType, m_leftovers, k, k_leftovers, typename P::LeftStream>
        LeftStreamL;

    typedef Stream<InType, n, k, k_leftovers, typename P::RightStream>
        RightStreamF;
    typedef Stream<InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    typedef Stream<OutType, m, n, 0, typename P::OutputStream> OutputStreamFF;
    typedef Stream<OutType, m_leftovers, n, 0, typename P::OutputStream>
        OutputStreamLF;

    typedef MulKernel<InType, OutType, typename P::Kernel,
                      typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<InType, OutType, typename P::Kernel,
                      typename P::OutputStream, m, n_leftovers, k>
        KernelFL;
    typedef MulKernel<InType, OutType, typename P::Kernel,
                      typename P::OutputStream, m_leftovers, n, k>
        KernelLF;
    typedef MulKernel<InType, OutType, typename P::Kernel,
                      typename P::OutputStream, m_leftovers, n_leftovers, k>
        KernelLL;

#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamLF::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    std::uint8_t* packed_lhs = params.scratch;
    std::uint8_t* packed_rhs =
        params.scratch + LeftStreamF::Scratch(params.left_stream);

    // Pack full RHS first.
    std::uint8_t* packed_rhs_chunk = packed_rhs;
    const int packed_rhs_chunk_size =
        RightStreamF::PackedStride(params.right_stream);

    {
      const std::uint8_t* rhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.rhs);
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs_chunk));
        rhs_chunk += rhs_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs_chunk));
    }

    // Multiply RHS by LHS one LHS chunk at a time.
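    // The packed RHS stays resident in scratch for the whole multiplication,
    // so each LHS chunk packed below is streamed against every packed RHS
    // chunk (KernelFF for full tiles, KernelFL for the n-leftover column)
    // before the next LHS chunk is packed.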
    const std::uint8_t* lhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.lhs);
    std::uint8_t* result_strip =
        reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);
      const int result_strip_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs));

        result_chunk = result_strip;
        packed_rhs_chunk = packed_rhs;

        for (int j = 0; j < rhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                             reinterpret_cast<const InType*>(packed_rhs_chunk),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));
          result_chunk += result_chunk_size;
          packed_rhs_chunk += packed_rhs_chunk_size;
        }

        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        lhs_chunk += lhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover LHS chunk.
    if (m_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream);

      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs));

      result_chunk = result_strip;
      packed_rhs_chunk = packed_rhs;

      for (int i = 0; i < rhs_chunks; ++i) {
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));
        result_chunk += result_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                         reinterpret_cast<const InType*>(packed_rhs_chunk),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};

class GemmExecutorPackLHS {
 public:
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m);
    const int lhs_scratch =
        lhs_chunks *
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    const int rhs_scratch =
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
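    // This executor mirrors GemmExecutorPackRHS with the operand roles
    // swapped: the whole LHS is packed into scratch up front, and the RHS is
    // re-packed one chunk at a time in the multiplication loop below.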
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    typedef Stream<InType, m, k, k_leftovers, typename P::LeftStream>
        LeftStreamF;
    typedef Stream<InType, m_leftovers, k, k_leftovers, typename P::LeftStream>
        LeftStreamL;

    typedef Stream<InType, n, k, k_leftovers, typename P::RightStream>
        RightStreamF;
    typedef Stream<InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    typedef Stream<OutType, m, n, 0, typename P::OutputStream> OutputStreamFF;
    typedef Stream<OutType, m, n_leftovers, 0, typename P::OutputStream>
        OutputStreamFL;

    typedef MulKernel<InType, OutType, typename P::Kernel,
                      typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<InType, OutType, typename P::Kernel,
                      typename P::OutputStream, m, n_leftovers, k>
        KernelFL;
    typedef MulKernel<InType, OutType, typename P::Kernel,
                      typename P::OutputStream, m_leftovers, n, k>
        KernelLF;
    typedef MulKernel<InType, OutType, typename P::Kernel,
                      typename P::OutputStream, m_leftovers, n_leftovers, k>
        KernelLL;

#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamFL::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    std::uint8_t* packed_rhs = params.scratch;
    std::uint8_t* packed_lhs =
        params.scratch + RightStreamF::Scratch(params.right_stream);

    // Pack full LHS first.
    std::uint8_t* packed_lhs_chunk = packed_lhs;
    const int packed_lhs_chunk_size =
        LeftStreamF::PackedStride(params.left_stream);

    {
      const std::uint8_t* lhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.lhs);
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs_chunk));
        lhs_chunk += lhs_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs_chunk));
    }

    // Multiply RHS by LHS one RHS chunk at a time.
    const std::uint8_t* rhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.rhs);
    std::uint8_t* result_strip =
        reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);
      const int result_strip_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs));

        result_chunk = result_strip;
        packed_lhs_chunk = packed_lhs;

        for (int j = 0; j < lhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                             reinterpret_cast<const InType*>(packed_rhs),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));
          result_chunk += result_chunk_size;
          packed_lhs_chunk += packed_lhs_chunk_size;
        }

        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        rhs_chunk += rhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover RHS chunk.
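    // n_leftovers is a template parameter, so this "static if" (like the
    // matching m_leftovers branch in GemmExecutorPackRHS) is resolved at
    // compile time and costs nothing when the dimensions divide evenly.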
    if (n_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream);

      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs));

      result_chunk = result_strip;
      packed_lhs_chunk = packed_lhs;

      for (int i = 0; i < lhs_chunks; ++i) {
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));
        result_chunk += result_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                         reinterpret_cast<const InType*>(packed_rhs),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};

namespace internal {

inline int CalculateCacheFriendlyTasksCount(int cache_size,
                                            int constant_memory,
                                            int per_chunk_memory,
                                            int total_dim, int chunk_dim) {
  assert(constant_memory + per_chunk_memory < cache_size);
  const int available_cache = cache_size - constant_memory;
  const int available_chunks = available_cache / per_chunk_memory;
  const int chunks_count = (total_dim + chunk_dim - 1) / chunk_dim;
  return (chunks_count + available_chunks - 1) / available_chunks;
}

template <typename Params>
inline void UpdateCacheFriendlyTask(int m_offset, int m, int n_offset, int n,
                                    const Params& params,
                                    Params* task_params) {
  task_params->m = m;
  task_params->lhs =
      StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
          params.left_stream, params.lhs, m_offset, 0);
  task_params->n = n;
  task_params->rhs = StreamUtil<typename Params::InType,
                                typename Params::RightStream>::Offset(
      params.right_stream, params.rhs, n_offset, 0);
  task_params->result =
      StreamUtil<typename Params::OutType, typename Params::OutputStream>::
          Offset(params.fused_kernel.output_stream, params.result, m_offset,
                 n_offset);
}

}  // namespace internal

template <int cache_size = 256 * 1024>
class GemmExecutorPackRHSCacheFriendly {
 public:
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    return cache_size;
  }

  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStream;
    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStream;

    const int lhs_scratch = LeftStream::Scratch(params.left_stream);
    const int rhs_scratch = RightStream::Scratch(params.right_stream);

    const int cache_friendly_tasks_count =
        internal::CalculateCacheFriendlyTasksCount(
            cache_size, lhs_scratch, rhs_scratch, params.n, n);

    if (cache_friendly_tasks_count == 1) {
      GemmExecutorPackRHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
                                             n_leftovers, k_leftovers>(params);
      return;
    }

    const int cache_friendly_dim = params.n / cache_friendly_tasks_count;

    P task_params = params;
    for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
      internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim,
                                        cache_friendly_dim, params,
                                        &task_params);
      Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
    }
    const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
    internal::UpdateCacheFriendlyTask(0, params.m, dim_sum,
                                      params.n - dim_sum, params,
                                      &task_params);
    Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
  }
};

template <int cache_size = 256 * 1024>
class GemmExecutorPackLHSCacheFriendly {
 public:
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    return cache_size;
  }

  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStream;
    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStream;

    const int lhs_scratch = LeftStream::Scratch(params.left_stream);
    const int rhs_scratch = RightStream::Scratch(params.right_stream);

    const int cache_friendly_tasks_count =
        internal::CalculateCacheFriendlyTasksCount(
            cache_size, rhs_scratch, lhs_scratch, params.m, m);

    if (cache_friendly_tasks_count == 1) {
      GemmExecutorPackLHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
                                             n_leftovers, k_leftovers>(params);
      return;
    }

    const int cache_friendly_dim = params.m / cache_friendly_tasks_count;

    P task_params = params;
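    // Carve M into equal cache-sized slices; the trailing task below picks up
    // whatever rows remain after the division rounds down.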
    for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
      internal::UpdateCacheFriendlyTask(i * cache_friendly_dim,
                                        cache_friendly_dim, 0, params.n,
                                        params, &task_params);
      Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
    }
    const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
    internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0,
                                      params.n, params, &task_params);
    Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
  }
};

namespace internal {

// Stage 3.

template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n, int variable_k>
struct Dispatch3DStage3 {
  static void Execute(const P& params, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << variable_k
              << std::endl
              << std::flush;
#endif
#endif
    if (k == variable_k) {
      E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                                    variable_k>(params);
    } else {
      Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                       variable_k - 1>::Execute(params, k);
    }
  }
};

template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n>
struct Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0> {
  static void Execute(const P& params, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << 0 << std::endl
              << std::flush;
#endif
#endif
    if (k == 0) {
      E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                                    0>(params);
    } else {
      std::cerr << "FATAL: dispatch3DStage3 failed: ran out of cases."
                << std::endl
                << std::flush;
      std::exit(1);
    }
  }
};

// Stage 2.

template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int variable_n>
struct Dispatch3DStage2 {
  static void Execute(const P& params, int n, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << variable_n << std::endl
              << std::flush;
#endif
#endif
    if (n == variable_n) {
      Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, variable_n,
                       dim_k - 1>::Execute(params, k);
    } else {
      Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m,
                       variable_n - 1>::Execute(params, n, k);
    }
  }
};

template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m>
struct Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 0> {
  static void Execute(const P& params, int n, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << 0 << std::endl
              << std::flush;
#endif
#endif
    if (n == 0) {
      Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, 0,
                       dim_k - 1>::Execute(params, k);
    } else {
      std::cerr << "FATAL: dispatch3DStage2 failed: ran out of cases."
                << std::endl
                << std::flush;
      std::exit(1);
    }
  }
};

// Stage 1.

template <typename E, typename P, int dim_m, int dim_n, int dim_k,
          int variable_m>
struct Dispatch3DStage1 {
  static void Execute(const P& params, int m, int n, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << variable_m << std::endl
              << std::flush;
#endif
#endif
    if (m == variable_m) {
      Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, variable_m,
                       dim_n - 1>::Execute(params, n, k);
    } else {
      Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, variable_m - 1>::Execute(
          params, m, n, k);
    }
  }
};

template <typename E, typename P, int dim_m, int dim_n, int dim_k>
struct Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, 0> {
  static void Execute(const P& params, int m, int n, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << 0 << std::endl
              << std::flush;
#endif
#endif
    if (m == 0) {
      Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(
          params, n, k);
    } else {
      std::cerr << "FATAL: dispatch3DStage1 failed: ran out of cases."
                << std::endl
                << std::flush;
      std::exit(1);
    }
  }
};

}  // namespace internal

template <typename Executor, typename Params, int kernel_m, int kernel_n,
          int kernel_k>
inline void Gemm(const Params& params) {
  internal::Dispatch3DStage1<Executor, Params, kernel_m, kernel_n, kernel_k,
                             kernel_m - 1>::Execute(params,
                                                    params.m % kernel_m,
                                                    params.n % kernel_n,
                                                    params.k % kernel_k);
}

}  // namespace meta
}  // namespace gemmlowp

#endif  // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
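// Usage sketch (illustrative only -- `MyParams`, the scratch sizing and the
// 3x3x8 tile shape are assumptions, not definitions from this header): a
// caller fills in a params struct, provides scratch memory, and dispatches
// through Gemm with an executor and compile-time kernel tile sizes, e.g.
//
//   MyParams params;
//   // ... set params.m/n/k, lhs/rhs/result and the stream descriptors ...
//   params.scratch = scratch;  // >= Executor::EstimateScratchSize(...)
//   gemmlowp::meta::Gemm<gemmlowp::meta::GemmExecutorPackRHS, MyParams,
//                        3, 3, 8>(params);
//
// The three-stage dispatch above then selects, at runtime, the
// ExecuteDispatch3D specialization whose compile-time leftover sizes match
// (params.m % 3, params.n % 3, params.k % 8).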