/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ #include "ruy/matrix.h" // from @ruy #include "ruy/mul_params.h" // from @ruy #include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/internal/compatibility.h" namespace tflite { namespace cpu_backend_gemm { namespace detail { inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) { switch (cache_policy) { case CachePolicy::kNeverCache: return ruy::CachePolicy::kNeverCache; case CachePolicy::kCacheIfLargeSpeedup: return ruy::CachePolicy::kCacheIfLargeSpeedup; case CachePolicy::kAlwaysCache: return ruy::CachePolicy::kAlwaysCache; default: TFLITE_DCHECK(false); return ruy::CachePolicy::kNeverCache; } } template void MakeRuyMatrix(const MatrixParams& params, DataPointer data_ptr, ruy::Matrix* dst, bool use_caching = false) { ruy::Order ruy_order = params.order == Order::kColMajor ? ruy::Order::kColMajor : ruy::Order::kRowMajor; ruy::MakeSimpleLayout(params.rows, params.cols, ruy_order, dst->mutable_layout()); // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer. // It does care whether we assign to it a Scalar* or a const Scalar*. dst->set_data(data_ptr); dst->set_zero_point(params.zero_point); if (use_caching) { dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy)); } } // Floating-point case. template struct MakeRuyMulParamsImpl final { static void Run( const GemmParams& params, ruy::MulParams* ruy_mul_params) { static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint, ""); ruy_mul_params->set_bias(params.bias); ruy_mul_params->set_clamp_min(params.clamp_min); ruy_mul_params->set_clamp_max(params.clamp_max); } }; // Integer-quantized case with destination type narrower than int32 template struct MakeRuyMulParamsImpl final { static void Run( const GemmParams& params, ruy::MulParams* ruy_mul_params) { static_assert(sizeof(DstScalar) < sizeof(std::int32_t), ""); if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier) { ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint); ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent); } if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier) { ruy_mul_params->set_multiplier_fixedpoint_perchannel( params.multiplier_fixedpoint_perchannel); ruy_mul_params->set_multiplier_exponent_perchannel( params.multiplier_exponent_perchannel); } ruy_mul_params->set_bias(params.bias); ruy_mul_params->set_clamp_min(params.clamp_min); ruy_mul_params->set_clamp_max(params.clamp_max); } }; // Raw-integer case with destination type int32. template struct MakeRuyMulParamsImpl final { static void Run( const GemmParams& params, ruy::MulParams* ruy_mul_params) { ruy_mul_params->set_bias(params.bias); } }; template void MakeRuyMulParams( const GemmParams& params, ruy::MulParams* ruy_mul_params) { MakeRuyMulParamsImpl::Run( params, ruy_mul_params); } template struct GemmImplUsingRuy { static void Run( const MatrixParams& lhs_params, const LhsScalar* lhs_data, const MatrixParams& rhs_params, const RhsScalar* rhs_data, const MatrixParams& dst_params, DstScalar* dst_data, const GemmParams& params, CpuBackendContext* context) { ruy::Matrix ruy_lhs; ruy::Matrix ruy_rhs; ruy::Matrix ruy_dst; MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs, context->use_caching()); MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs, context->use_caching()); MakeRuyMatrix(dst_params, dst_data, &ruy_dst); ruy::MulParams ruy_mul_params; MakeRuyMulParams(params, &ruy_mul_params); ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, context->ruy_context(), &ruy_dst); } }; } // namespace detail } // namespace cpu_backend_gemm } // namespace tflite #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_