1 /* Copyright 2019 Google LLC. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef RUY_RUY_KERNEL_COMMON_H_ 17 #define RUY_RUY_KERNEL_COMMON_H_ 18 19 #include <algorithm> 20 #include <cstdint> 21 #include <type_traits> 22 23 #include "ruy/apply_multiplier.h" 24 #include "ruy/check_macros.h" 25 #include "ruy/mat.h" 26 #include "ruy/matrix.h" 27 #include "ruy/mul_params.h" 28 #include "ruy/opt_set.h" 29 #include "ruy/path.h" 30 #include "ruy/platform.h" 31 #include "ruy/profiler/instrumentation.h" 32 #include "ruy/side_pair.h" 33 #include "ruy/size_util.h" 34 #include "ruy/tune.h" 35 36 namespace ruy { 37 38 template <Path ThePath, typename LhsScalar, typename RhsScalar, 39 typename AccumScalar, typename DstScalar> 40 struct Kernel; 41 42 #define RUY_INHERIT_KERNEL(PARENT, CHILD) \ 43 template <typename LhsScalar, typename RhsScalar, typename DstScalar, \ 44 typename AccumScalar> \ 45 struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar> \ 46 : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> { \ 47 explicit Kernel(Tuning tuning) \ 48 : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>( \ 49 tuning) {} \ 50 }; 51 52 // KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. 53 // 54 // In other cases, we still define (empty) versions, so that dummy kernels 55 // can use the classes in function signatures. 56 #if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \ 57 RUY_PLATFORM_X86 58 59 #define RUY_ASM_FLAG_HAS_BIAS 0x1 60 #define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2 61 #define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4 62 #define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8 63 #define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10 64 #define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20 65 66 #define RUY_ASM_TYPE_ID_UINT8 1 67 #define RUY_ASM_TYPE_ID_INT8 2 68 #define RUY_ASM_TYPE_ID_INT16 3 69 #define RUY_ASM_TYPE_ID_INT32 4 70 71 template <typename DstScalar> 72 struct DstTypeId {}; 73 74 template <> 75 struct DstTypeId<std::uint8_t> { 76 static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8; 77 }; 78 79 template <> 80 struct DstTypeId<std::int8_t> { 81 static constexpr int kValue = RUY_ASM_TYPE_ID_INT8; 82 }; 83 84 template <> 85 struct DstTypeId<std::int16_t> { 86 static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; 87 }; 88 89 template <> 90 struct DstTypeId<std::int32_t> { 91 static constexpr int kValue = RUY_ASM_TYPE_ID_INT32; 92 }; 93 94 template <int LhsCols, int RhsCols> 95 struct KernelParams8bit { 96 static constexpr int kMaxDstTypeSize = 4; 97 98 const std::int32_t* bias; 99 const std::int32_t* lhs_sums; 100 const std::int32_t* rhs_sums; 101 const std::int8_t* lhs_base_ptr; 102 const std::int32_t* multiplier_fixedpoint; 103 const std::int32_t* multiplier_exponent; 104 const std::int8_t* rhs_base_ptr; 105 void* dst_base_ptr; 106 std::int32_t lhs_zero_point; 107 std::int32_t rhs_zero_point; 108 std::int32_t dst_zero_point; 109 std::int32_t prod_zp_depth; 110 std::int32_t start_row; 111 std::int32_t start_col; 112 std::int32_t last_row; 113 std::int32_t last_col; 114 std::int32_t dst_rows; 115 std::int32_t dst_cols; 116 std::int32_t lhs_stride; 117 std::int32_t rhs_stride; 118 std::int32_t dst_stride; 119 std::int32_t depth; 120 std::int32_t clamp_min; 121 std::int32_t clamp_max; 122 std::uint8_t flags; 123 std::uint8_t dst_type_id; 124 const std::int32_t zero_data[LhsCols] = {0}; 125 std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; 126 std::int32_t multiplier_fixedpoint_buf[LhsCols]; 127 std::int32_t multiplier_exponent_buf[LhsCols]; 128 }; 129 130 template <typename DstScalar, int LhsCols, int RhsCols> 131 void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, 132 const PMat<std::int8_t>& rhs, 133 const MulParams<std::int32_t, DstScalar>& mul_params, 134 int start_row, int start_col, int end_row, 135 int end_col, Mat<DstScalar>* dst, 136 KernelParams8bit<LhsCols, RhsCols>* params) { 137 using Params = KernelParams8bit<LhsCols, RhsCols>; 138 139 static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, ""); 140 141 const int depth = lhs.layout.rows; 142 RUY_DCHECK_EQ(start_row % LhsCols, 0); 143 RUY_DCHECK_EQ(start_col % RhsCols, 0); 144 RUY_DCHECK_EQ(end_row % LhsCols, 0); 145 RUY_DCHECK_EQ(end_col % RhsCols, 0); 146 147 params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; 148 params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; 149 params->flags = 0; 150 params->bias = params->zero_data; 151 if (mul_params.bias()) { 152 params->bias = mul_params.bias(); 153 params->flags |= RUY_ASM_FLAG_HAS_BIAS; 154 } 155 if (lhs.sums) { 156 params->lhs_sums = lhs.sums; 157 params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; 158 } 159 if (rhs.sums) { 160 params->rhs_sums = rhs.sums; 161 params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; 162 } 163 if (mul_params.channel_dimension() == ChannelDimension::kCol) { 164 params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; 165 } 166 params->start_row = start_row; 167 params->start_col = start_col; 168 params->last_row = end_row - LhsCols; 169 params->last_col = end_col - RhsCols; 170 params->lhs_stride = lhs.layout.stride; 171 params->rhs_stride = rhs.layout.stride; 172 params->dst_stride = sizeof(DstScalar) * dst->layout.stride; 173 params->lhs_zero_point = lhs.zero_point; 174 params->rhs_zero_point = rhs.zero_point; 175 params->dst_zero_point = dst->zero_point; 176 params->depth = depth; 177 params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; 178 params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; 179 if (mul_params.multiplier_fixedpoint_perchannel()) { 180 // Temporary release-assert to debug some crashes in an application. 181 RUY_CHECK(mul_params.multiplier_exponent_perchannel()); 182 params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL; 183 params->multiplier_fixedpoint = 184 mul_params.multiplier_fixedpoint_perchannel(); 185 params->multiplier_exponent = mul_params.multiplier_exponent_perchannel(); 186 } else { 187 params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; 188 params->multiplier_exponent = params->multiplier_exponent_buf; 189 for (int i = 0; i < LhsCols; i++) { 190 params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint(); 191 params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent(); 192 } 193 } 194 params->clamp_min = mul_params.clamp_min(); 195 params->clamp_max = mul_params.clamp_max(); 196 params->dst_rows = dst->layout.rows; 197 params->dst_cols = dst->layout.cols; 198 199 RUY_DCHECK_LT(params->last_row, params->dst_rows); 200 RUY_DCHECK_LT(params->last_col, params->dst_cols); 201 202 params->dst_type_id = DstTypeId<DstScalar>::kValue; 203 params->dst_base_ptr = 204 dst->data.get() + start_col * dst->layout.stride + start_row; 205 206 // Temporary release-asserts to debug some crashes in an application. 207 RUY_CHECK(params->multiplier_fixedpoint); 208 RUY_CHECK(params->multiplier_exponent); 209 RUY_CHECK(params->bias); 210 } 211 212 template <int LhsCols, int RhsCols> 213 struct KernelParamsFloat { 214 const float* lhs_base_ptr; 215 const float* rhs_base_ptr; 216 float* dst_base_ptr; 217 const float* bias; 218 std::int32_t start_row; 219 std::int32_t start_col; 220 std::int32_t last_row; 221 std::int32_t last_col; 222 std::int32_t dst_rows; 223 std::int32_t dst_cols; 224 std::int32_t lhs_stride; 225 std::int32_t rhs_stride; 226 std::int32_t dst_stride; 227 std::int32_t depth; 228 float clamp_min; 229 float clamp_max; 230 std::uint8_t flags; 231 const float zero_data[LhsCols] = {0}; 232 float dst_tmp_buf[LhsCols * RhsCols]; 233 }; 234 235 template <int LhsCols, int RhsCols> 236 inline void MakeKernelParamsFloat(const PMat<float>& lhs, 237 const PMat<float>& rhs, 238 const MulParams<float, float>& mul_params, 239 int start_row, int start_col, int end_row, 240 int end_col, Mat<float>* dst, 241 KernelParamsFloat<LhsCols, RhsCols>* params) { 242 const int depth = lhs.layout.rows; 243 RUY_DCHECK_EQ(start_row % LhsCols, 0); 244 RUY_DCHECK_EQ(start_col % RhsCols, 0); 245 RUY_DCHECK_EQ(end_row % LhsCols, 0); 246 RUY_DCHECK_EQ(end_col % RhsCols, 0); 247 248 params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; 249 params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; 250 params->dst_base_ptr = 251 dst->data.get() + start_col * dst->layout.stride + start_row; 252 253 std::uint8_t flags = 0; 254 params->bias = params->zero_data; 255 if (mul_params.bias()) { 256 params->bias = mul_params.bias(); 257 flags |= RUY_ASM_FLAG_HAS_BIAS; 258 } 259 if (mul_params.channel_dimension() == ChannelDimension::kCol) { 260 flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; 261 } 262 params->flags = flags; 263 params->start_row = start_row; 264 params->start_col = start_col; 265 params->last_row = end_row - LhsCols; 266 params->last_col = end_col - RhsCols; 267 params->lhs_stride = sizeof(float) * lhs.layout.stride; 268 params->rhs_stride = sizeof(float) * rhs.layout.stride; 269 params->dst_stride = sizeof(float) * dst->layout.stride; 270 params->depth = depth; 271 params->clamp_min = mul_params.clamp_min(); 272 params->clamp_max = mul_params.clamp_max(); 273 params->dst_rows = dst->layout.rows; 274 params->dst_cols = dst->layout.cols; 275 276 RUY_DCHECK_LT(params->last_row, params->dst_rows); 277 RUY_DCHECK_LT(params->last_col, params->dst_cols); 278 } 279 280 #else // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && 281 // RUY_OPT(ASM)) || RUY_PLATFORM_X86 282 283 template <int LhsCols, int RhsCols> 284 struct KernelParams8bit {}; 285 286 template <int LhsCols, int RhsCols> 287 struct KernelParamsFloat {}; 288 289 #endif // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && 290 // RUY_OPT(ASM)) || RUY_PLATFORM_X86 291 292 } // namespace ruy 293 294 #endif // RUY_RUY_KERNEL_COMMON_H_ 295