/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_COMMON_H_
#define RUY_RUY_KERNEL_COMMON_H_

#include <algorithm>
#include <cstdint>
#include <type_traits>

#include "ruy/apply_multiplier.h"
#include "ruy/check_macros.h"
#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/side_pair.h"
#include "ruy/size_util.h"
#include "ruy/tune.h"

namespace ruy {

// Primary kernel template, parameterized on the code path and on the LHS,
// RHS, accumulator and destination scalar types. It is only declared here;
// no definition is provided in this header.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
struct Kernel;

// Defines the partial specialization Kernel<CHILD, ...> (for every
// combination of scalar types) as a struct that publicly derives from
// Kernel<PARENT, ...> with the same scalar types. The only member it adds is
// an explicit constructor taking a `Tuning` and forwarding it to the parent's
// constructor.
//
// Note: the macro's template parameter list declares DstScalar before
// AccumScalar, while the Kernel argument list is (..., AccumScalar,
// DstScalar). The arguments are passed by name in the Kernel<...> argument
// lists below, so they line up with the primary template regardless of the
// declaration order.
#define RUY_INHERIT_KERNEL(PARENT, CHILD)                               \
  template <typename LhsScalar, typename RhsScalar, typename DstScalar, \
            typename AccumScalar>                                       \
  struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar>    \
      : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> {  \
    explicit Kernel(Tuning tuning)                                      \
        : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>( \
              tuning) {}                                                \
  };

// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code.
//
// In other cases, we still define (empty) versions, so that dummy kernels
// can use the classes in function signatures.
// The full KernelParams definitions and their Make* helpers are only compiled
// when targeting NEON (32- or 64-bit) with RUY_OPT(ASM) enabled, or when
// targeting x86. Otherwise the empty fallback structs in the #else branch are
// used.
#if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \
    RUY_PLATFORM_X86

// Bit values OR-ed into the `flags` field of KernelParams8bit /
// KernelParamsFloat by the MakeKernelParams* functions below.
#define RUY_ASM_FLAG_HAS_BIAS 0x1
#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
#define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20

// Numeric identifiers for the destination scalar type; one of these is stored
// in KernelParams8bit::dst_type_id (via DstTypeId below).
#define RUY_ASM_TYPE_ID_UINT8 1
#define RUY_ASM_TYPE_ID_INT8 2
#define RUY_ASM_TYPE_ID_INT16 3
#define RUY_ASM_TYPE_ID_INT32 4

// Maps a destination scalar type to its RUY_ASM_TYPE_ID_* value, exposed as
// the static member `kValue`. The primary template is empty, so using
// DstTypeId<T>::kValue with a type that has no specialization below is a
// compile error.
template <typename DstScalar>
struct DstTypeId {};

template <>
struct DstTypeId<std::uint8_t> {
  static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
};

template <>
struct DstTypeId<std::int8_t> {
  static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
};

template <>
struct DstTypeId<std::int16_t> {
  static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
};

template <>
struct DstTypeId<std::int32_t> {
  static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
};

// Parameter block for 8-bit kernels, populated by MakeKernelParams8bit.
// LhsCols / RhsCols size the trailing fixed-size buffers.
//
// NOTE(review): given the RUY_ASM_* naming, this struct is presumably read by
// hand-written kernels at fixed field offsets; confirm against those kernels
// before reordering, resizing or inserting fields.
template <int LhsCols, int RhsCols>
struct KernelParams8bit {
  // Upper bound, in bytes, on sizeof(DstScalar). Enforced by a static_assert
  // in MakeKernelParams8bit and used to size dst_tmp_buf.
  static constexpr int kMaxDstTypeSize = 4;

  // Points at the caller's bias, or at zero_data when no bias was supplied.
  const std::int32_t* bias;
  // Only assigned by MakeKernelParams8bit when the corresponding sums pointer
  // is non-null; RUY_ASM_FLAG_HAS_LHS_SUMS / _HAS_RHS_SUMS record that.
  const std::int32_t* lhs_sums;
  const std::int32_t* rhs_sums;
  const std::int8_t* lhs_base_ptr;
  // Point either at the caller's per-channel arrays or at the *_buf members
  // at the end of this struct.
  const std::int32_t* multiplier_fixedpoint;
  const std::int32_t* multiplier_exponent;
  const std::int8_t* rhs_base_ptr;
  void* dst_base_ptr;
  std::int32_t lhs_zero_point;
  std::int32_t rhs_zero_point;
  std::int32_t dst_zero_point;
  // lhs_zero_point * rhs_zero_point * depth.
  std::int32_t prod_zp_depth;
  std::int32_t start_row;
  std::int32_t start_col;
  // end_row - LhsCols and end_col - RhsCols respectively.
  std::int32_t last_row;
  std::int32_t last_col;
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  // Copied unscaled from the packed matrices' layout strides.
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  // Destination layout stride multiplied by sizeof(DstScalar).
  std::int32_t dst_stride;
  std::int32_t depth;
  std::int32_t clamp_min;
  std::int32_t clamp_max;
  // Bitwise OR of RUY_ASM_FLAG_* values.
  std::uint8_t flags;
  // One of the RUY_ASM_TYPE_ID_* values.
  std::uint8_t dst_type_id;
  // Zero-initialized array that `bias` points to when there is no bias.
  const std::int32_t zero_data[LhsCols] = {0};
  // Sized for LhsCols * RhsCols values of the largest destination type. Not
  // written by MakeKernelParams8bit.
  // NOTE(review): presumably scratch space for the kernels -- confirm.
  std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
  // Filled with LhsCols copies of the single multiplier when no per-channel
  // multipliers are supplied.
  std::int32_t multiplier_fixedpoint_buf[LhsCols];
  std::int32_t multiplier_exponent_buf[LhsCols];
};

// Fills `*params` for an 8-bit kernel invocation covering destination rows
// [start_row, end_row) and columns [start_col, end_col).
//
// - lhs, rhs: packed int8 operands; `depth` is taken from lhs.layout.rows.
// - mul_params: supplies bias, channel dimension, multipliers and clamp
//   bounds.
// - start_row / end_row must be multiples of LhsCols, and start_col / end_col
//   multiples of RhsCols (checked with RUY_DCHECK_EQ).
// - dst: destination matrix; its data pointer, stride, zero point and
//   dimensions are read.
// - params: output; most fields are overwritten. Note that lhs_sums /
//   rhs_sums are left untouched when the corresponding sums pointer is null,
//   and dst_tmp_buf is never written here.
//
// The static_assert rejects destination types wider than
// Params::kMaxDstTypeSize bytes.
template <typename DstScalar, int LhsCols, int RhsCols>
void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
                          const PMat<std::int8_t>& rhs,
                          const MulParams<std::int32_t, DstScalar>& mul_params,
                          int start_row, int start_col, int end_row,
                          int end_col, Mat<DstScalar>* dst,
                          KernelParams8bit<LhsCols, RhsCols>* params) {
  using Params = KernelParams8bit<LhsCols, RhsCols>;

  static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");

  const int depth = lhs.layout.rows;
  RUY_DCHECK_EQ(start_row % LhsCols, 0);
  RUY_DCHECK_EQ(start_col % RhsCols, 0);
  RUY_DCHECK_EQ(end_row % LhsCols, 0);
  RUY_DCHECK_EQ(end_col % RhsCols, 0);

  // Offset each packed operand to the first row/column of the requested block.
  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
  params->flags = 0;
  // Default to the zero-filled array so `bias` is always a valid pointer; it
  // is replaced below when the caller supplied a bias.
  params->bias = params->zero_data;
  if (mul_params.bias()) {
    params->bias = mul_params.bias();
    params->flags |= RUY_ASM_FLAG_HAS_BIAS;
  }
  // The sums pointers are only stored when present; otherwise the fields keep
  // whatever value they had and only the (unset) flag indicates absence.
  if (lhs.sums) {
    params->lhs_sums = lhs.sums;
    params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
  }
  if (rhs.sums) {
    params->rhs_sums = rhs.sums;
    params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
  }
  if (mul_params.channel_dimension() == ChannelDimension::kCol) {
    params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
  }
  params->start_row = start_row;
  params->start_col = start_col;
  // Since end_row / end_col are multiples of the block size, these are the
  // starting row / column of the last block in the requested range.
  params->last_row = end_row - LhsCols;
  params->last_col = end_col - RhsCols;
  params->lhs_stride = lhs.layout.stride;
  params->rhs_stride = rhs.layout.stride;
  // Unlike the lhs/rhs strides, the destination stride is scaled by the
  // destination element size.
  params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
  params->lhs_zero_point = lhs.zero_point;
  params->rhs_zero_point = rhs.zero_point;
  params->dst_zero_point = dst->zero_point;
  params->depth = depth;
  params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
  // This flag is set unconditionally.
  params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
  if (mul_params.multiplier_fixedpoint_perchannel()) {
    // Per-channel multipliers: point directly at the caller's arrays.
    // NOTE(review): multiplier_exponent_perchannel() is not checked for null
    // here; presumably callers always supply both arrays together -- confirm.
    params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
    params->multiplier_fixedpoint =
        mul_params.multiplier_fixedpoint_perchannel();
    params->multiplier_exponent = mul_params.multiplier_exponent_perchannel();
  } else {
    // Single multiplier: replicate it LhsCols times into the buffers embedded
    // in `params` and point the multiplier pointers at those buffers.
    params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
    params->multiplier_exponent = params->multiplier_exponent_buf;
    for (int i = 0; i < LhsCols; i++) {
      params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint();
      params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent();
    }
  }
  params->clamp_min = mul_params.clamp_min();
  params->clamp_max = mul_params.clamp_max();
  params->dst_rows = dst->layout.rows;
  params->dst_cols = dst->layout.cols;

  // The last block must start inside the destination matrix.
  RUY_DCHECK_LT(params->last_row, params->dst_rows);
  RUY_DCHECK_LT(params->last_col, params->dst_cols);

  params->dst_type_id = DstTypeId<DstScalar>::kValue;
  // Element (start_row, start_col) of the destination, using
  // start_col * stride + start_row as the element offset.
  params->dst_base_ptr =
      dst->data.get() + start_col * dst->layout.stride + start_row;
}

// Parameter block for float kernels, populated by MakeKernelParamsFloat.
// LhsCols / RhsCols size the trailing fixed-size buffers.
//
// NOTE(review): as with KernelParams8bit, this is presumably read at fixed
// field offsets by hand-written kernels; confirm before changing the layout.
template <int LhsCols, int RhsCols>
struct KernelParamsFloat {
  const float* lhs_base_ptr;
  const float* rhs_base_ptr;
  float* dst_base_ptr;
  // Points at the caller's bias, or at zero_data when no bias was supplied.
  const float* bias;
  std::int32_t start_row;
  std::int32_t start_col;
  // end_row - LhsCols and end_col - RhsCols respectively.
  std::int32_t last_row;
  std::int32_t last_col;
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  // All three strides are stored multiplied by sizeof(float).
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  std::int32_t dst_stride;
  std::int32_t depth;
  float clamp_min;
  float clamp_max;
  // Bitwise OR of RUY_ASM_FLAG_* values (only HAS_BIAS and
  // CHANNEL_DIMENSION_IS_COL are set by MakeKernelParamsFloat).
  std::uint8_t flags;
  // Zero-initialized array that `bias` points to when there is no bias.
  const float zero_data[LhsCols] = {0};
  // Sized for one LhsCols x RhsCols block of floats. Not written by
  // MakeKernelParamsFloat.
  // NOTE(review): presumably scratch space for the kernels -- confirm.
  float dst_tmp_buf[LhsCols * RhsCols];
};

// Fills `*params` for a float kernel invocation covering destination rows
// [start_row, end_row) and columns [start_col, end_col).
//
// - lhs, rhs: packed float operands; `depth` is taken from lhs.layout.rows.
// - mul_params: supplies bias, channel dimension and clamp bounds.
// - start_row / end_row must be multiples of LhsCols, and start_col / end_col
//   multiples of RhsCols (checked with RUY_DCHECK_EQ).
// - dst: destination matrix; its data pointer, stride and dimensions are read.
// - params: output; every field except dst_tmp_buf is written.
template <int LhsCols, int RhsCols>
inline void MakeKernelParamsFloat(const PMat<float>& lhs,
                                  const PMat<float>& rhs,
                                  const MulParams<float, float>& mul_params,
                                  int start_row, int start_col, int end_row,
                                  int end_col, Mat<float>* dst,
                                  KernelParamsFloat<LhsCols, RhsCols>* params) {
  const int depth = lhs.layout.rows;
  RUY_DCHECK_EQ(start_row % LhsCols, 0);
  RUY_DCHECK_EQ(start_col % RhsCols, 0);
  RUY_DCHECK_EQ(end_row % LhsCols, 0);
  RUY_DCHECK_EQ(end_col % RhsCols, 0);

  // Base pointers are computed in units of float elements.
  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
  params->dst_base_ptr =
      dst->data.get() + start_col * dst->layout.stride + start_row;

  // Flags are accumulated locally and stored once below.
  std::uint8_t flags = 0;
  // Default to the zero-filled array so `bias` is always a valid pointer.
  params->bias = params->zero_data;
  if (mul_params.bias()) {
    params->bias = mul_params.bias();
    flags |= RUY_ASM_FLAG_HAS_BIAS;
  }
  if (mul_params.channel_dimension() == ChannelDimension::kCol) {
    flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
  }
  params->flags = flags;
  params->start_row = start_row;
  params->start_col = start_col;
  // Since end_row / end_col are multiples of the block size, these are the
  // starting row / column of the last block in the requested range.
  params->last_row = end_row - LhsCols;
  params->last_col = end_col - RhsCols;
  // Unlike the base pointers above, the strides are scaled by sizeof(float).
  params->lhs_stride = sizeof(float) * lhs.layout.stride;
  params->rhs_stride = sizeof(float) * rhs.layout.stride;
  params->dst_stride = sizeof(float) * dst->layout.stride;
  params->depth = depth;
  params->clamp_min = mul_params.clamp_min();
  params->clamp_max = mul_params.clamp_max();
  params->dst_rows = dst->layout.rows;
  params->dst_cols = dst->layout.cols;

  // The last block must start inside the destination matrix.
  RUY_DCHECK_LT(params->last_row, params->dst_rows);
  RUY_DCHECK_LT(params->last_col, params->dst_cols);
}

#else  // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
       // RUY_OPT(ASM)) || RUY_PLATFORM_X86

// Empty placeholder so the type name can still appear in function signatures
// on configurations where the full definition above is not compiled.
template <int LhsCols, int RhsCols>
struct KernelParams8bit {};

// Empty placeholder; see KernelParams8bit above.
template <int LhsCols, int RhsCols>
struct KernelParamsFloat {};

#endif  // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
        // RUY_OPT(ASM)) || RUY_PLATFORM_X86

}  // namespace ruy

#endif  // RUY_RUY_KERNEL_COMMON_H_