// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// output_msa.h: optimized MSA specializations of the templates in output.h.

#ifndef GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
#define GEMMLOWP_INTERNAL_OUTPUT_MSA_H_

#include "output.h"

#include <msa.h>

namespace gemmlowp {

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 9 bits
    // (this takes full care of non-negative elements).
    v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
    // Pack every 32-bit element into 16 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
    // Detect negative elements with arithmetic shift right (we
    // get a 16-bit mask of all zeroes or all ones for every element).
    v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
    // Zero out negative elements.
    signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
        reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
    // Pack every element into 8 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
    // Return 4 uint8_t elements as uint32_t.
    output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
    return output;
  }
};
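// For reference, a scalar sketch of the lane-wise trick above (illustrative
// only, not part of the library): sat_s_w(x, 8) clamps each lane to
// [-256, 255], and zeroing the negative lanes before taking the low byte
// then amounts to
//
//   std::uint8_t SaturatingCastToUint8(std::int32_t x) {
//     return static_cast<std::uint8_t>(x < 0 ? 0 : (x > 255 ? 255 : x));
//   }
//
// applied independently to each of the four lanes.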
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 9 bits
    // (this takes full care of non-negative elements).
    v4i32 tmp_lo = __builtin_msa_sat_s_w(input.reg[0], 8);
    v4i32 tmp_hi = __builtin_msa_sat_s_w(input.reg[1], 8);
    // Pack every 32-bit element into 16 bits,
    // combining all 8 elements into one vector.
    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
    // Detect negative elements with arithmetic shift right (we
    // get a 16-bit mask of all zeroes or all ones for every element).
    v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
    // Zero out negative elements.
    signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
        reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
    // Pack every element into 8 bits.
    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
    // Return 8 uint8_t elements as 2 uint32_t's.
    output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
    output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
    return output;
  }
};

#define GEMMLOWP_MIPS_SAT_U8_16(out, in0, in1, in2, in3)                     \
  {                                                                          \
    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 8);                              \
    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 8);                              \
    v4i32 tmp2 = __builtin_msa_sat_s_w(in2, 8);                              \
    v4i32 tmp3 = __builtin_msa_sat_s_w(in3, 8);                              \
    tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
        reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)));      \
    tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
        reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2)));      \
    v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15);  \
    v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15);  \
    signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
        reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
    signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
        reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
    signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b(                  \
        reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0)));  \
    out = reinterpret_cast<v16i8>(signs0);                                   \
  }

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferUint8<16> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
                            input.reg[2], input.reg[3]);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferUint8<32> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
                            input.reg[2], input.reg[3]);
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[1], input.reg[4], input.reg[5],
                            input.reg[6], input.reg[7]);
    return output;
  }
};

#undef GEMMLOWP_MIPS_SAT_U8_16
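// Note on lane order: GEMMLOWP_MIPS_SAT_U8_16 packs its four v4i32 inputs
// into one v16i8 whose byte lanes are, in order, in0[0..3], in1[0..3],
// in2[0..3], in3[0..3], since pckev_h/pckev_b keep the even-indexed (i.e.
// low) halves of each wider lane. For example (a sketch of the expected
// behavior, not a test in the library): an input lane holding -1 saturates
// to -1, is flagged by the srai sign mask and zeroed, so it comes out as the
// byte 0, while a lane holding 300 saturates to 255 and stays 255.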
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferInt16<4> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 16 bits.
    v8i16 tmp =
        reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(input.reg[0], 15));
    output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
    output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
    output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
    output.reg[3] = __builtin_msa_copy_s_h(tmp, 6);
    return output;
  }
};

#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1)                  \
  {                                                             \
    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15);                \
    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15);                \
    out = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(tmp1),  \
                                reinterpret_cast<v8i16>(tmp0)); \
  }

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferInt16<8> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferInt16<16> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferInt16<32> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[2], input.reg[4], input.reg[5]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[3], input.reg[6], input.reg[7]);
    return output;
  }
};

#undef GEMMLOWP_MIPS_SAT_I16_8
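// Scalar equivalent of the int16 casts above (illustrative only):
// sat_s_w(x, 15) clamps each lane to the int16 range and pckev_h keeps the
// low halfword, so per lane this computes
//
//   std::int16_t SaturatingCastToInt16(std::int32_t x) {
//     return static_cast<std::int16_t>(
//         x < -32768 ? -32768 : (x > 32767 ? 32767 : x));
//   }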
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    } else {
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    } else {
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
  static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
                  int col) {
    *dst->data(row + 0, col) = src.buf.reg[0];
    *dst->data(row + 1, col) = src.buf.reg[1];
    *dst->data(row + 2, col) = src.buf.reg[2];
    *dst->data(row + 3, col) = src.buf.reg[3];
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
  static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
    } else {
      *dst->data(row + 0, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 0);
      *dst->data(row + 1, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 1);
      *dst->data(row + 2, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 2);
      *dst->data(row + 3, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 3);
      *dst->data(row + 4, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 4);
      *dst->data(row + 5, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 5);
      *dst->data(row + 6, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 6);
      *dst->data(row + 7, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 7);
    }
  }
};

inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
  RegBlockInt32<4, 4> result;
  v4i32 tmp0, tmp1;
  tmp0 = __builtin_msa_ilvr_w(src.buf.reg[1], src.buf.reg[0]);
  tmp1 = __builtin_msa_ilvr_w(src.buf.reg[3], src.buf.reg[2]);
  result.buf.reg[0] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  result.buf.reg[1] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  tmp0 = __builtin_msa_ilvl_w(src.buf.reg[1], src.buf.reg[0]);
  tmp1 = __builtin_msa_ilvl_w(src.buf.reg[3], src.buf.reg[2]);
  result.buf.reg[2] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  result.buf.reg[3] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  return result;
}
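// How Transpose works, with r0..r3 denoting src.buf.reg[0..3] (a lane-by-lane
// walkthrough, not additional library code):
//   ilvr_w(r1, r0)  = [r0[0], r1[0], r0[1], r1[1]]   (interleave low halves)
//   ilvr_w(r3, r2)  = [r2[0], r3[0], r2[1], r3[1]]
//   ilvr_d of those = [r0[0], r1[0], r2[0], r3[0]]   -> transposed row 0
//   ilvl_d of those = [r0[1], r1[1], r2[1], r3[1]]   -> transposed row 1
// The ilvl_w pair produces rows 2 and 3 in the same way.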
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      const auto transpose = Transpose(src);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
  static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::int16_t buf[16];
    StoreInt16x8(buf + 0, src.buf.reg[0]);
    StoreInt16x8(buf + 8, src.buf.reg[1]);
    for (int i = 0; i < 4; i++) {
      for (int j = 0; j < 4; j++) {
        *dst->data(row + i, col + j) = buf[i + 4 * j];
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
      }
    } else {
      RegBlockInt32<4, 4> top;
      top.buf.reg[0] = src.buf.reg[0];
      top.buf.reg[1] = src.buf.reg[2];
      top.buf.reg[2] = src.buf.reg[4];
      top.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top = Transpose(top);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom;
      bottom.buf.reg[0] = src.buf.reg[1];
      bottom.buf.reg[1] = src.buf.reg[3];
      bottom.buf.reg[2] = src.buf.reg[5];
      bottom.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom = Transpose(bottom);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
  static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 4; i++) {
        StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      std::int16_t buf[32];
      StoreInt16x8(buf + 0, src.buf.reg[0]);
      StoreInt16x8(buf + 8, src.buf.reg[1]);
      StoreInt16x8(buf + 16, src.buf.reg[2]);
      StoreInt16x8(buf + 24, src.buf.reg[3]);
      for (int i = 0; i < 8; i++) {
        for (int j = 0; j < 4; j++) {
          *dst->data(row + i, col + j) = buf[i + 8 * j];
        }
      }
    }
  }
};
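// The int16 4x4 store and the RowMajor path of the int16 8x4 store above
// scatter through a small stack buffer laid out column by column:
// buf[i + 8 * j] holds the element at (row i, col j) of the 8x4 block
// (buf[i + 4 * j] for the 4x4 case), so e.g. row 2, col 3 of the 8x4 block
// is read from buf[2 + 8 * 3] = buf[26]. An in-register transpose (as in the
// int16 8x8 case below) would also work; the buffer presumably keeps these
// smaller cases simple.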
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 8; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
      }
    } else {
      RegBlockInt32<4, 4> top_left;
      top_left.buf.reg[0] = src.buf.reg[0];
      top_left.buf.reg[1] = src.buf.reg[2];
      top_left.buf.reg[2] = src.buf.reg[4];
      top_left.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top_left = Transpose(top_left);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_left;
      bottom_left.buf.reg[0] = src.buf.reg[1];
      bottom_left.buf.reg[1] = src.buf.reg[3];
      bottom_left.buf.reg[2] = src.buf.reg[5];
      bottom_left.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom_left = Transpose(bottom_left);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col),
                     transpose_bottom_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> top_right;
      top_right.buf.reg[0] = src.buf.reg[8];
      top_right.buf.reg[1] = src.buf.reg[10];
      top_right.buf.reg[2] = src.buf.reg[12];
      top_right.buf.reg[3] = src.buf.reg[14];
      const auto transpose_top_right = Transpose(top_right);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col + 4),
                     transpose_top_right.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_right;
      bottom_right.buf.reg[0] = src.buf.reg[9];
      bottom_right.buf.reg[1] = src.buf.reg[11];
      bottom_right.buf.reg[2] = src.buf.reg[13];
      bottom_right.buf.reg[3] = src.buf.reg[15];
      const auto transpose_bottom_right = Transpose(bottom_right);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col + 4),
                     transpose_bottom_right.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
  static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 8; i++) {
        StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      // top-left 4x4
      v4i32 t0 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(src.buf.reg[1], src.buf.reg[0]));
      v4i32 t1 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(src.buf.reg[3], src.buf.reg[2]));
      v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
      v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
      // top-right 4x4
      v4i32 t2 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(src.buf.reg[5], src.buf.reg[4]));
      v4i32 t3 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(src.buf.reg[7], src.buf.reg[6]));
      v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
      v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
      // bottom-left 4x4
      v4i32 t4 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvl_h(src.buf.reg[1], src.buf.reg[0]));
      v4i32 t5 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvl_h(src.buf.reg[3], src.buf.reg[2]));
      v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
      v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
      // bottom-right 4x4
      v4i32 t6 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvl_h(src.buf.reg[5], src.buf.reg[4]));
      v4i32 t7 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvl_h(src.buf.reg[7], src.buf.reg[6]));
      v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
      v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));

      StoreInt16x8(dst->data(row + 0, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u2, u0)));
      StoreInt16x8(dst->data(row + 1, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u2, u0)));
      StoreInt16x8(dst->data(row + 2, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u3, u1)));
      StoreInt16x8(dst->data(row + 3, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u3, u1)));
      StoreInt16x8(dst->data(row + 4, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u6, u4)));
      StoreInt16x8(dst->data(row + 5, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u6, u4)));
      StoreInt16x8(dst->data(row + 6, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u7, u5)));
      StoreInt16x8(dst->data(row + 7, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u7, u5)));
    }
  }
};
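// The RowMajor path above performs an in-register 8x8 transpose of 16-bit
// elements in three interleave stages of doubling granularity:
//   ilvr_h/ilvl_h pair up 16-bit lanes from adjacent source vectors,
//   ilvr_w/ilvl_w merge those pairs into groups of four,
//   ilvr_d/ilvl_d splice two 64-bit half-rows into each full output row.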
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
    } else {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
                  int col) {
    const std::uint32_t src_reg = src.buf.reg[0];
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src_reg >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
    }
    for (int i = 0; i < 4; i++) {
      *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t buf[16];
    StoreUint8x16(buf, src.buf.reg[0]);
    for (int c = 0; c < 4; c++) {
      for (int r = 0; r < 4; r++) {
        *dst->data(row + r, col + c) = buf[r + 4 * c];
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t buf[32];
    StoreUint8x16(buf, src.buf.reg[0]);
    StoreUint8x16(buf + 16, src.buf.reg[1]);
    for (int c = 0; c < 4; c++) {
      for (int r = 0; r < 8; r++) {
        *dst->data(row + r, col + c) = buf[r + 8 * c];
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t buf[64];
    StoreUint8x16(buf, src.buf.reg[0]);
    StoreUint8x16(buf + 16, src.buf.reg[1]);
    StoreUint8x16(buf + 32, src.buf.reg[2]);
    StoreUint8x16(buf + 48, src.buf.reg[3]);
    for (int c = 0; c < 8; c++) {
      for (int r = 0; r < 8; r++) {
        *dst->data(row + r, col + c) = buf[r + 8 * c];
      }
    }
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_