// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_msa.h: optimized MSA specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_

#include <msa.h>

namespace gemmlowp {

template <>
struct FixedPointRawTypeTraits<v4i32> {
  typedef std::int32_t ScalarRawType;
  static constexpr int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<v8i16> {
  typedef std::int16_t ScalarRawType;
  static constexpr int kLanes = 8;
};

template <>
inline v4i32 BitAnd(v4i32 a, v4i32 b) {
  return reinterpret_cast<v4i32>(__builtin_msa_and_v(
      reinterpret_cast<v16u8>(a), reinterpret_cast<v16u8>(b)));
}

template <>
inline v8i16 BitAnd(v8i16 a, v8i16 b) {
  return reinterpret_cast<v8i16>(__builtin_msa_and_v(
      reinterpret_cast<v16u8>(a), reinterpret_cast<v16u8>(b)));
}

template <>
inline v4i32 BitOr(v4i32 a, v4i32 b) {
  return reinterpret_cast<v4i32>(__builtin_msa_or_v(
      reinterpret_cast<v16u8>(a), reinterpret_cast<v16u8>(b)));
}

template <>
inline v8i16 BitOr(v8i16 a, v8i16 b) {
  return reinterpret_cast<v8i16>(__builtin_msa_or_v(
      reinterpret_cast<v16u8>(a), reinterpret_cast<v16u8>(b)));
}

template <>
inline v4i32 BitXor(v4i32 a, v4i32 b) {
  return reinterpret_cast<v4i32>(__builtin_msa_xor_v(
      reinterpret_cast<v16u8>(a), reinterpret_cast<v16u8>(b)));
}

template <>
inline v8i16 BitXor(v8i16 a, v8i16 b) {
  return reinterpret_cast<v8i16>(__builtin_msa_xor_v(
      reinterpret_cast<v16u8>(a), reinterpret_cast<v16u8>(b)));
}

template <>
inline v4i32 BitNot(v4i32 a) {
  return reinterpret_cast<v4i32>(__builtin_msa_nor_v(
      reinterpret_cast<v16u8>(a), reinterpret_cast<v16u8>(a)));
}

template <>
inline v8i16 BitNot(v8i16 a) {
  return reinterpret_cast<v8i16>(__builtin_msa_nor_v(
      reinterpret_cast<v16u8>(a), reinterpret_cast<v16u8>(a)));
}

template <>
inline v4i32 Add(v4i32 a, v4i32 b) {
  return __builtin_msa_addv_w(a, b);
}

template <>
inline v8i16 Add(v8i16 a, v8i16 b) {
  return __builtin_msa_addv_h(a, b);
}

template <>
inline v4i32 Sub(v4i32 a, v4i32 b) {
  return __builtin_msa_subv_w(a, b);
}

template <>
inline v8i16 Sub(v8i16 a, v8i16 b) {
  return __builtin_msa_subv_h(a, b);
}

template <>
inline v4i32 Neg(v4i32 a) {
  v4i32 zeroes = __builtin_msa_ldi_w(0);
  return __builtin_msa_subv_w(zeroes, a);
}

template <>
inline v8i16 Neg(v8i16 a) {
  v8i16 zeroes = __builtin_msa_ldi_h(0);
  return __builtin_msa_subv_h(zeroes, a);
}

template <>
inline v4i32 ShiftLeft(v4i32 a, int offset) {
  return __builtin_msa_sll_w(a, __builtin_msa_fill_w(offset));
}

template <>
inline v8i16 ShiftLeft(v8i16 a, int offset) {
  return __builtin_msa_sll_h(a, __builtin_msa_fill_h(offset));
}
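// Illustrative sketch (added for exposition, not part of the original header):
// these specializations let the generic templates declared in fixedpoint.h
// operate on whole MSA registers, one lane at a time. Assuming an MSA-enabled
// build, a caller might combine them like this (Dup, MaskIfGreaterThan and
// SelectUsingMask are defined further down in this file):
//
//   v4i32 a = Dup<v4i32>(1000);
//   v4i32 b = ShiftLeft(Dup<v4i32>(-3), 2);             // four lanes of -12
//   v4i32 s = Add(a, b);                                 // four lanes of 988
//   v4i32 positive = MaskIfGreaterThan(s, Dup<v4i32>(0));
//   v4i32 abs = SelectUsingMask(positive, s, Neg(s));    // lane-wise |s|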
template <>
inline v4i32 ShiftRight(v4i32 a, int offset) {
  return __builtin_msa_sra_w(a, __builtin_msa_fill_w(offset));
}

template <>
inline v8i16 ShiftRight(v8i16 a, int offset) {
  return __builtin_msa_sra_h(a, __builtin_msa_fill_h(offset));
}

template <>
inline v4i32 SelectUsingMask(v4i32 if_mask, v4i32 then_val, v4i32 else_val) {
  if_mask = reinterpret_cast<v4i32>(
      __builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
                           reinterpret_cast<v16u8>(else_val),
                           reinterpret_cast<v16u8>(then_val)));
  return if_mask;
}

template <>
inline v8i16 SelectUsingMask(v8i16 if_mask, v8i16 then_val, v8i16 else_val) {
  if_mask = reinterpret_cast<v8i16>(
      __builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
                           reinterpret_cast<v16u8>(else_val),
                           reinterpret_cast<v16u8>(then_val)));
  return if_mask;
}

template <>
inline v4i32 MaskIfEqual(v4i32 a, v4i32 b) {
  return __builtin_msa_ceq_w(a, b);
}

template <>
inline v8i16 MaskIfEqual(v8i16 a, v8i16 b) {
  return __builtin_msa_ceq_h(a, b);
}

template <>
inline v4i32 MaskIfNotEqual(v4i32 a, v4i32 b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline v8i16 MaskIfNotEqual(v8i16 a, v8i16 b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline v4i32 MaskIfZero(v4i32 a) {
  return __builtin_msa_ceqi_w(a, 0);
}

template <>
inline v8i16 MaskIfZero(v8i16 a) {
  return __builtin_msa_ceqi_h(a, 0);
}

template <>
inline v4i32 MaskIfNonZero(v4i32 a) {
  return BitNot(MaskIfZero(a));
}

template <>
inline v8i16 MaskIfNonZero(v8i16 a) {
  return BitNot(MaskIfZero(a));
}

template <>
inline v4i32 MaskIfGreaterThan(v4i32 a, v4i32 b) {
  return __builtin_msa_clt_s_w(b, a);
}

template <>
inline v8i16 MaskIfGreaterThan(v8i16 a, v8i16 b) {
  return __builtin_msa_clt_s_h(b, a);
}

template <>
inline v4i32 MaskIfGreaterThanOrEqual(v4i32 a, v4i32 b) {
  return __builtin_msa_cle_s_w(b, a);
}

template <>
inline v8i16 MaskIfGreaterThanOrEqual(v8i16 a, v8i16 b) {
  return __builtin_msa_cle_s_h(b, a);
}

template <>
inline v4i32 MaskIfLessThan(v4i32 a, v4i32 b) {
  return __builtin_msa_clt_s_w(a, b);
}

template <>
inline v8i16 MaskIfLessThan(v8i16 a, v8i16 b) {
  return __builtin_msa_clt_s_h(a, b);
}

template <>
inline v4i32 MaskIfLessThanOrEqual(v4i32 a, v4i32 b) {
  return __builtin_msa_cle_s_w(a, b);
}

template <>
inline v8i16 MaskIfLessThanOrEqual(v8i16 a, v8i16 b) {
  return __builtin_msa_cle_s_h(a, b);
}

template <>
inline bool All(v4i32 a) {
  return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
}

template <>
inline bool All(v8i16 a) {
  return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
}

template <>
inline bool Any(v4i32 a) {
  return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
}

template <>
inline bool Any(v8i16 a) {
  return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
}

template <>
inline v4i32 RoundingHalfSum(v4i32 a, v4i32 b) {
  return __builtin_msa_aver_s_w(a, b);
}

template <>
inline v8i16 RoundingHalfSum(v8i16 a, v8i16 b) {
  return __builtin_msa_aver_s_h(a, b);
}
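// Illustrative scalar sketch (added for exposition; the helper below is
// hypothetical and not used elsewhere in gemmlowp): per the MSA documentation,
// the aver_s instructions used above compute a per-lane rounded average with
// the intermediate sum kept at full precision. In plain integer arithmetic,
// one 32-bit lane behaves roughly as follows:
inline std::int32_t RoundingHalfSumScalarSketch(std::int32_t a, std::int32_t b) {
  // Widen so the carry bit of the sum is not lost.
  const std::int64_t sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  // Adding 1 before the arithmetic shift rounds the halfway case upwards.
  return static_cast<std::int32_t>((sum + 1) >> 1);
}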
template <>
inline v4i32 SaturatingRoundingDoublingHighMul(v4i32 a, v4i32 b) {
  return __builtin_msa_mulr_q_w(a, b);
}

template <>
inline v8i16 SaturatingRoundingDoublingHighMul(v8i16 a, v8i16 b) {
  return __builtin_msa_mulr_q_h(a, b);
}

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, 1> {
  static v4i32 eval(v4i32 x) {
    static_assert(Exponent >= 0 && Exponent < 32, "");
    if (Exponent < 5) {
      for (int i = 0; i < Exponent; i++) {
        x = __builtin_msa_adds_s_w(x, x);
      }
      return x;
    } else {
      // Saturate each signed 32-bit element to (32 - Exponent)
      // bits (this takes full care of negative elements).
      v4i32 res = __builtin_msa_sat_s_w(x, 31 - Exponent);
      // Set tmp to 0x7FFFFFFF for those elements which saturated
      // to smaller (positive) values and 0 for all others.
      v4i32 tmp = __builtin_msa_srli_w(__builtin_msa_clt_s_w(res, x), 1);
      // Shift the saturated elements. The positive saturated elements
      // will have Exponent trailing zero bits after the shift. Those
      // need to be ones, not zeroes.
      res = __builtin_msa_slli_w(res, Exponent);
      // Finally, set those trailing zero bits to ones.
      res = reinterpret_cast<v4i32>(__builtin_msa_or_v(
          reinterpret_cast<v16u8>(res), reinterpret_cast<v16u8>(tmp)));
      return res;
    }
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> {
  static v8i16 eval(v8i16 x) {
    static_assert(Exponent >= 0 && Exponent < 16, "");
    if (Exponent < 5) {
      for (int i = 0; i < Exponent; i++) {
        x = __builtin_msa_adds_s_h(x, x);
      }
      return x;
    } else {
      // Saturate each signed 16-bit element to (16 - Exponent)
      // bits (this takes full care of negative elements).
      v8i16 res = __builtin_msa_sat_s_h(x, 15 - Exponent);
      // Set tmp to 0x7FFF for those elements which saturated
      // to smaller (positive) values and 0 for all others.
      v8i16 tmp = __builtin_msa_srli_h(__builtin_msa_clt_s_h(res, x), 1);
      // Shift the saturated elements. The positive saturated elements
      // will have Exponent trailing zero bits after the shift. Those
      // need to be ones, not zeroes.
      res = __builtin_msa_slli_h(res, Exponent);
      // Finally, set those trailing zero bits to ones.
      res = reinterpret_cast<v8i16>(__builtin_msa_or_v(
          reinterpret_cast<v16u8>(res), reinterpret_cast<v16u8>(tmp)));
      return res;
    }
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> {
  static v4i32 eval(v4i32 x) {
    static_assert(-31 <= Exponent && Exponent <= -1, "");
    // Isolate the sign bits.
    v4i32 sign = __builtin_msa_srli_w(x, 31);
    // Decrement the negative elements by 1 (with saturation).
    x = __builtin_msa_subs_s_w(x, sign);
    // Arithmetic shift right with rounding.
    // The srari instruction rounds all midpoint values towards +infinity.
    // It will correctly round negative midpoint values as we just
    // decremented the negative values by 1.
    return __builtin_msa_srari_w(x, -Exponent);
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> {
  static v8i16 eval(v8i16 x) {
    static_assert(-15 <= Exponent && Exponent <= -1, "");
    // Isolate the sign bits.
    v8i16 sign = __builtin_msa_srli_h(x, 15);
    // Decrement the negative elements by 1 (with saturation).
    x = __builtin_msa_subs_s_h(x, sign);
    // Arithmetic shift right with rounding.
    // The srari instruction rounds all midpoint values towards +infinity.
    // It will correctly round negative midpoint values as we just
    // decremented the negative values by 1.
    return __builtin_msa_srari_h(x, -Exponent);
  }
};
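// Illustrative scalar sketch (added for exposition; the helper below is
// hypothetical and not used elsewhere in gemmlowp): the "subtract the sign
// bit, then do a rounding arithmetic shift" trick used in the specializations
// above, and again in RoundingDivideByPOT below, amounts to rounding halfway
// values away from zero. For one 32-bit lane, with 0 < exponent < 32:
inline std::int32_t RoundingRightShiftScalarSketch(std::int32_t x, int exponent) {
  // Decrementing negative values by 1 turns srari/srar's round-half-upwards
  // behaviour into round-half-away-from-zero overall.
  std::int64_t v = static_cast<std::int64_t>(x) - (x < 0 ? 1 : 0);
  // srari/srar add half of the divisor before shifting right arithmetically.
  const std::int64_t nudge = static_cast<std::int64_t>(1) << (exponent - 1);
  return static_cast<std::int32_t>((v + nudge) >> exponent);
}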
template <>
inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) {
  v4i32 e = __builtin_msa_fill_w(exponent);
  // Isolate the sign bits.
  v4i32 sign = __builtin_msa_srli_w(x, 31);
  // Reset them to 0 if exponent is 0.
  sign = __builtin_msa_min_s_w(sign, e);
  // Decrement the negative elements by 1 (with saturation)
  // if exponent is non-zero.
  x = __builtin_msa_subs_s_w(x, sign);
  // Arithmetic shift right with rounding.
  // The srar instruction rounds all midpoint values towards +infinity.
  // It will correctly round negative midpoint values as we just
  // decremented the negative values by 1.
  return __builtin_msa_srar_w(x, e);
}

template <>
inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) {
  v8i16 e = __builtin_msa_fill_h(exponent);
  // Isolate the sign bits.
  v8i16 sign = __builtin_msa_srli_h(x, 15);
  // Reset them to 0 if exponent is 0.
  sign = __builtin_msa_min_s_h(sign, e);
  // Decrement the negative elements by 1 (with saturation)
  // if exponent is non-zero.
  x = __builtin_msa_subs_s_h(x, sign);
  // Arithmetic shift right with rounding.
  // The srar instruction rounds all midpoint values towards +infinity.
  // It will correctly round negative midpoint values as we just
  // decremented the negative values by 1.
  return __builtin_msa_srar_h(x, e);
}

template <>
inline v4i32 Dup<v4i32>(std::int32_t x) {
  return __builtin_msa_fill_w(x);
}

template <>
inline v8i16 Dup<v8i16>(std::int16_t x) {
  return __builtin_msa_fill_h(x);
}

// So far this is only needed for int16.
template <>
inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
  return __builtin_msa_adds_s_h(a, b);
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_