// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_ #define GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_ #include "../internal/common.h" #ifdef GEMMLOWP_NEON #include "quantized_mul_kernels.h" #include "single_thread_gemm.h" #include "streams.h" namespace gemmlowp { namespace meta { void gemm_q8_strided(std::uint8_t* scratch, const std::uint8_t* lhs, const std::uint8_t* rhs, std::int32_t m, std::int32_t n, std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset, std::int32_t result_offset, std::int32_t multiplicative_offset, std::int32_t shift, std::uint8_t* result, std::int32_t result_stride) { #ifdef DEBUG #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE std::cout << "Legacy::GemmQ8." << std::endl; #endif #endif typedef GemmParams Params; Params params; params.m = m; params.n = n; params.k = k; params.lhs = lhs; params.rhs = rhs; params.result = result; params.scratch = scratch; params.left_stream.count = k; params.left_stream.stride = k; params.left_stream.multiplicative_sum_offset = rhs_offset; params.left_stream.additive_sum_offset = result_offset + k * lhs_offset * rhs_offset; params.right_stream.count = k; params.right_stream.stride = k; params.right_stream.multiplicative_sum_offset = lhs_offset; params.right_stream.additive_sum_offset = 0; params.fused_kernel.kernel.multiplicative_offset = multiplicative_offset; params.fused_kernel.kernel.rounding_offset = (1 << (shift - 1)); params.fused_kernel.kernel.shift = -shift; params.fused_kernel.kernel.count = k; params.fused_kernel.output_stream.stride = result_stride; Gemm(params); } void gemv_q8(std::uint8_t* scratch, const std::uint8_t* lhs, const std::uint8_t* rhs, std::int32_t n, std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset, std::int32_t result_offset, std::int32_t multiplicative_offset, std::int32_t shift, std::uint8_t* result) { #ifdef DEBUG #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE std::cout << "Legacy::GemvQ8." << std::endl; #endif #endif typedef GemmParams Params; Params params; params.m = 1; params.n = n; params.k = k; params.lhs = lhs; params.rhs = rhs; params.result = result; params.scratch = scratch; params.left_stream.count = k; params.left_stream.stride = k; params.left_stream.multiplicative_sum_offset = rhs_offset; params.left_stream.additive_sum_offset = result_offset + k * lhs_offset * rhs_offset; params.right_stream.count = k; params.right_stream.stride = k; params.right_stream.multiplicative_sum_offset = lhs_offset; params.right_stream.additive_sum_offset = 0; params.fused_kernel.kernel.multiplicative_offset = multiplicative_offset; params.fused_kernel.kernel.rounding_offset = (1 << (shift - 1)); params.fused_kernel.kernel.shift = -shift; params.fused_kernel.kernel.count = k; params.fused_kernel.output_stream.stride = n; if (k < 1536) { Gemm(params); } else { Gemm(params); } } void gemm_i32_strided(std::uint8_t* scratch, const std::uint8_t* lhs, const std::uint8_t* rhs, std::int32_t m, std::int32_t n, std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset, std::int32_t* result, std::int32_t result_stride) { #ifdef DEBUG #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE std::cout << "Legacy::GemmI32." << std::endl; #endif #endif typedef GemmParams Params; Params params; params.m = m; params.n = n; params.k = k; params.lhs = lhs; params.rhs = rhs; params.result = result; params.scratch = scratch; params.left_stream.count = k; params.left_stream.stride = k; params.left_stream.multiplicative_sum_offset = rhs_offset; params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset; params.right_stream.count = k; params.right_stream.stride = k; params.right_stream.multiplicative_sum_offset = lhs_offset; params.right_stream.additive_sum_offset = 0; params.fused_kernel.kernel.count = k; params.fused_kernel.output_stream.stride = result_stride * 4; Gemm(params); } void gemv_i32(std::uint8_t* scratch, const std::uint8_t* lhs, const std::uint8_t* rhs, std::int32_t n, std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset, std::int32_t* result) { #ifdef DEBUG #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE std::cout << "Legacy::GemvI32." << std::endl; #endif #endif typedef GemmParams Params; Params params; params.m = 1; params.n = n; params.k = k; params.lhs = lhs; params.rhs = rhs; params.result = result; params.scratch = scratch; params.left_stream.count = k; params.left_stream.stride = k; params.left_stream.multiplicative_sum_offset = rhs_offset; params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset; params.right_stream.count = k; params.right_stream.stride = k; params.right_stream.multiplicative_sum_offset = lhs_offset; params.right_stream.additive_sum_offset = 0; params.fused_kernel.kernel.count = k; params.fused_kernel.output_stream.stride = 0; if (k < 1664) { Gemm(params); } else { Gemm(params); } } void gemm_f_strided(std::uint8_t* scratch, const std::uint8_t* lhs, const std::uint8_t* rhs, std::int32_t m, std::int32_t n, std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset, float result_offset, float* result, std::int32_t result_stride) { #ifdef DEBUG #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE std::cout << "Legacy::GemmF." << std::endl; #endif #endif typedef GemmParams Params; Params params; params.m = m; params.n = n; params.k = k; params.lhs = lhs; params.rhs = rhs; params.result = result; params.scratch = scratch; params.left_stream.count = k; params.left_stream.stride = k; params.left_stream.multiplicative_sum_offset = rhs_offset; params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset; params.right_stream.count = k; params.right_stream.stride = k; params.right_stream.multiplicative_sum_offset = lhs_offset; params.right_stream.additive_sum_offset = 0; params.fused_kernel.kernel.count = k; params.fused_kernel.kernel.scale = result_offset; params.fused_kernel.output_stream.stride = result_stride * 4; Gemm(params); } void gemv_f(std::uint8_t* scratch, const std::uint8_t* lhs, const std::uint8_t* rhs, std::int32_t n, std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset, float result_offset, float* result) { #ifdef DEBUG #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE std::cout << "Legacy::GemvF." << std::endl; #endif #endif typedef GemmParams Params; Params params; params.m = 1; params.n = n; params.k = k; params.lhs = lhs; params.rhs = rhs; params.result = result; params.scratch = scratch; params.left_stream.count = k; params.left_stream.stride = k; params.left_stream.multiplicative_sum_offset = rhs_offset; params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset; params.right_stream.count = k; params.right_stream.stride = k; params.right_stream.multiplicative_sum_offset = lhs_offset; params.right_stream.additive_sum_offset = 0; params.fused_kernel.kernel.count = k; params.fused_kernel.kernel.scale = result_offset; params.fused_kernel.output_stream.stride = 0; if (k < 1664) { Gemm(params); } else { Gemm(params); } } } // namespace meta } // namespace gemmlowp #else #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!" #endif #endif // GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_