// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// output_neon.h: optimized NEON specializations of the templates in output.h.

#ifndef GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
#define GEMMLOWP_INTERNAL_OUTPUT_NEON_H_

#include "output.h"

#include <arm_neon.h>

namespace gemmlowp {

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x4_t res_16 = vqmovn_s32(input.reg[0]);
    uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16));
    output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[0] = vqmovun_s16(res_16);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferUint8<16> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16_0 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    int16x8_t res_16_1 =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    output.reg[0] = vqmovun_s16(res_16_0);
    output.reg[1] = vqmovun_s16(res_16_1);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferUint8<32> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16[4];
    for (int i = 0; i < 4; i++) {
      res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
                               vqmovn_s32(input.reg[2 * i + 1]));
    }
    for (int i = 0; i < 4; i++) {
      output.reg[i] = vqmovun_s16(res_16[i]);
    }
    return output;
  }
};
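
// Specializations of the saturating cast to std::int16_t: each 32-bit
// accumulator lane is narrowed to 16 bits with saturation (vqmovn_s32).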
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferInt16<4> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    output.reg[0] = vqmovn_s32(input.reg[0]);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferInt16<8> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    output.reg[0] =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferInt16<16> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    output.reg[0] =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[1] =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferInt16<32> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    output.reg[0] =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[1] =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    output.reg[2] =
        vcombine_s16(vqmovn_s32(input.reg[4]), vqmovn_s32(input.reg[5]));
    output.reg[3] =
        vcombine_s16(vqmovn_s32(input.reg[6]), vqmovn_s32(input.reg[7]));
    return output;
  }
};
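
// The StoreFinalOutputImpl specializations below write a register block to
// the destination matrix: contiguous vector stores when the block layout
// matches the destination storage order, per-lane stores (or an in-register
// transpose) otherwise.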
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    } else {
      vst1q_lane_s32(dst->data(row + 0, col), src.buf.reg[0], 0);
      vst1q_lane_s32(dst->data(row + 1, col), src.buf.reg[0], 1);
      vst1q_lane_s32(dst->data(row + 2, col), src.buf.reg[0], 2);
      vst1q_lane_s32(dst->data(row + 3, col), src.buf.reg[0], 3);
      vst1q_lane_s32(dst->data(row + 4, col), src.buf.reg[1], 0);
      vst1q_lane_s32(dst->data(row + 5, col), src.buf.reg[1], 1);
      vst1q_lane_s32(dst->data(row + 6, col), src.buf.reg[1], 2);
      vst1q_lane_s32(dst->data(row + 7, col), src.buf.reg[1], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
  static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt16x4(dst->data(row, col), src.buf.reg[0]);
    } else {
      vst1_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
      vst1_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
      vst1_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
      vst1_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
  static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
    } else {
      vst1q_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
      vst1q_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
      vst1q_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
      vst1q_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
      vst1q_lane_s16(dst->data(row + 4, col), src.buf.reg[0], 4);
      vst1q_lane_s16(dst->data(row + 5, col), src.buf.reg[0], 5);
      vst1q_lane_s16(dst->data(row + 6, col), src.buf.reg[0], 6);
      vst1q_lane_s16(dst->data(row + 7, col), src.buf.reg[0], 7);
    }
  }
};
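
// In-register transpose of a 4x4 int32 block: vtrnq_s32 interleaves pairs of
// rows, then vcombine_s32 of the low/high halves yields the transposed rows.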
inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
  const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]);
  const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]);
  RegBlockInt32<4, 4> result;
  result.buf.reg[0] =
      vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));
  result.buf.reg[1] =
      vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));
  result.buf.reg[2] =
      vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0]));
  result.buf.reg[3] =
      vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1]));
  return result;
}

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::int32_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 4; i++) {
      vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
  static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1_s16(dst->data(row, col + 0), vget_low_s16(src.buf.reg[0]));
      vst1_s16(dst->data(row, col + 1), vget_high_s16(src.buf.reg[0]));
      vst1_s16(dst->data(row, col + 2), vget_low_s16(src.buf.reg[1]));
      vst1_s16(dst->data(row, col + 3), vget_high_s16(src.buf.reg[1]));
    } else {
      const int16x4x2_t t0 = vtrn_s16(vget_low_s16(src.buf.reg[0]),
                                      vget_high_s16(src.buf.reg[0]));
      const int16x4x2_t t1 = vtrn_s16(vget_low_s16(src.buf.reg[1]),
                                      vget_high_s16(src.buf.reg[1]));
      const int32x4x2_t t =
          vtrnq_s32(vreinterpretq_s32_s16(vcombine_s16(t0.val[0], t0.val[1])),
                    vreinterpretq_s32_s16(vcombine_s16(t1.val[0], t1.val[1])));
      vst1_s16(dst->data(row + 0, col),
               vget_low_s16(vreinterpretq_s16_s32(t.val[0])));
      vst1_s16(dst->data(row + 1, col),
               vget_high_s16(vreinterpretq_s16_s32(t.val[0])));
      vst1_s16(dst->data(row + 2, col),
               vget_low_s16(vreinterpretq_s16_s32(t.val[1])));
      vst1_s16(dst->data(row + 3, col),
               vget_high_s16(vreinterpretq_s16_s32(t.val[1])));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      RegBlockInt32<4, 4> top;
      top.buf.reg[0] = src.buf.reg[0];
      top.buf.reg[1] = src.buf.reg[2];
      top.buf.reg[2] = src.buf.reg[4];
      top.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top = Transpose(top);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom;
      bottom.buf.reg[0] = src.buf.reg[1];
      bottom.buf.reg[1] = src.buf.reg[3];
      bottom.buf.reg[2] = src.buf.reg[5];
      bottom.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom = Transpose(bottom);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
  static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
      vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
      vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
      vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
    } else {
      const int16x8x2_t t0 = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
      const int16x8x2_t t1 = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
      const int32x4x2_t u0 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[0]),
                                       vreinterpretq_s32_s16(t1.val[0]));
      const int32x4x2_t u1 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[1]),
                                       vreinterpretq_s32_s16(t1.val[1]));
      vst1_s16(dst->data(row + 0, col),
               vget_low_s16(vreinterpretq_s16_s32(u0.val[0])));
      vst1_s16(dst->data(row + 1, col),
               vget_low_s16(vreinterpretq_s16_s32(u1.val[0])));
      vst1_s16(dst->data(row + 2, col),
               vget_low_s16(vreinterpretq_s16_s32(u0.val[1])));
      vst1_s16(dst->data(row + 3, col),
               vget_low_s16(vreinterpretq_s16_s32(u1.val[1])));
      vst1_s16(dst->data(row + 4, col),
               vget_high_s16(vreinterpretq_s16_s32(u0.val[0])));
      vst1_s16(dst->data(row + 5, col),
               vget_high_s16(vreinterpretq_s16_s32(u1.val[0])));
      vst1_s16(dst->data(row + 6, col),
               vget_high_s16(vreinterpretq_s16_s32(u0.val[1])));
      vst1_s16(dst->data(row + 7, col),
               vget_high_s16(vreinterpretq_s16_s32(u1.val[1])));
    }
  }
};
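
// An 8x8 int32 block is handled as four 4x4 sub-blocks; for row-major
// destinations each sub-block is transposed in registers before being stored.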
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 8; i++) {
        vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      RegBlockInt32<4, 4> top_left;
      top_left.buf.reg[0] = src.buf.reg[0];
      top_left.buf.reg[1] = src.buf.reg[2];
      top_left.buf.reg[2] = src.buf.reg[4];
      top_left.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top_left = Transpose(top_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_left;
      bottom_left.buf.reg[0] = src.buf.reg[1];
      bottom_left.buf.reg[1] = src.buf.reg[3];
      bottom_left.buf.reg[2] = src.buf.reg[5];
      bottom_left.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom_left = Transpose(bottom_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride,
                  transpose_bottom_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> top_right;
      top_right.buf.reg[0] = src.buf.reg[8];
      top_right.buf.reg[1] = src.buf.reg[10];
      top_right.buf.reg[2] = src.buf.reg[12];
      top_right.buf.reg[3] = src.buf.reg[14];
      const auto transpose_top_right = Transpose(top_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_right;
      bottom_right.buf.reg[0] = src.buf.reg[9];
      bottom_right.buf.reg[1] = src.buf.reg[11];
      bottom_right.buf.reg[2] = src.buf.reg[13];
      bottom_right.buf.reg[3] = src.buf.reg[15];
      const auto transpose_bottom_right = Transpose(bottom_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride + 4,
                  transpose_bottom_right.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1q_s32(dst_ptr, src.buf.reg[0]);
    } else {
      int row_stride = dst->rows_stride();
      vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::RowMajor) {
      vst1q_s32(dst_ptr, src.buf.reg[0]);
    } else {
      int col_stride = dst->cols_stride();
      vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
      vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
      vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
      vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<1, 4>, DstType> {
  static void Run(const RegBlockInt16<1, 4>& src, DstType* dst, int row,
                  int col) {
    std::int16_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::RowMajor) {
      vst1_s16(dst_ptr, src.buf.reg[0]);
    } else {
      int col_stride = dst->cols_stride();
      vst1_lane_s16(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
      vst1_lane_s16(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
      vst1_lane_s16(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
      vst1_lane_s16(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
    }
  }
};
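
// Uint8 blocks with only 4 lanes are kept packed in a std::uint32_t scalar
// register, so they are unpacked with byte shifts; larger blocks use NEON
// byte stores.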
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
                  int col) {
    const std::uint32_t src_reg = src.buf.reg[0];
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src_reg >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1_u8(dst_ptr, src.buf.reg[0]);
    } else {
      const int row_stride = dst->rows_stride();
      vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
      vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
      vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
      vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
      vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    const int row_stride = dst->rows_stride();
    const int col_stride = dst->cols_stride();
    for (int i = 0; i < 2; i++) {
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 3);
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 4);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 5);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 6);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 7);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
      }
    } else {
      for (int i = 0; i < 4; i++) {
        int row_stride = dst->rows_stride();
        std::uint8_t* col_ptr = dst_ptr + i;
        vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
        vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
        vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
        vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
        vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
        vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
        vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
        vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
      }
    }
  }
};
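
// In-register transpose of an 8x8 uint8 block, built from vtrn operations at
// 8-, 16- and 32-bit granularity.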
inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
  uint8x8x2_t a[4];
  a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]);
  a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]);
  a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]);
  a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]);
  uint16x4x2_t b[4];
  b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]),
                  vreinterpret_u16_u8(a[1].val[0]));
  b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]),
                  vreinterpret_u16_u8(a[1].val[1]));
  b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]),
                  vreinterpret_u16_u8(a[3].val[0]));
  b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]),
                  vreinterpret_u16_u8(a[3].val[1]));
  uint32x2x2_t c[4];
  c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]),
                  vreinterpret_u32_u16(b[2].val[0]));
  c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]),
                  vreinterpret_u32_u16(b[3].val[0]));
  c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]),
                  vreinterpret_u32_u16(b[2].val[1]));
  c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]),
                  vreinterpret_u32_u16(b[3].val[1]));
  RegBlockUint8<8, 8> result;
  result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]);
  result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]);
  result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]);
  result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]);
  result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]);
  result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]);
  result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]);
  result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]);
  return result;
}

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::uint8_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 8; i++) {
      vst1_u8(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};
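
// For a row-major destination, the 8x8 int16 block is transposed with
// vtrnq_s16/vtrnq_s32 and then written out as 4-lane halves per row.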
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
  static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
      vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
      vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
      vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
      vst1q_s16(dst->data(row, col + 4), src.buf.reg[4]);
      vst1q_s16(dst->data(row, col + 5), src.buf.reg[5]);
      vst1q_s16(dst->data(row, col + 6), src.buf.reg[6]);
      vst1q_s16(dst->data(row, col + 7), src.buf.reg[7]);
    } else {
      int16x8x2_t a[4];
      a[0] = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
      a[1] = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
      a[2] = vtrnq_s16(src.buf.reg[4], src.buf.reg[5]);
      a[3] = vtrnq_s16(src.buf.reg[6], src.buf.reg[7]);
      int32x4x2_t b[4];
      b[0] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[0]),
                       vreinterpretq_s32_s16(a[1].val[0]));
      b[1] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[1]),
                       vreinterpretq_s32_s16(a[1].val[1]));
      b[2] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[0]),
                       vreinterpretq_s32_s16(a[3].val[0]));
      b[3] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[1]),
                       vreinterpretq_s32_s16(a[3].val[1]));
      vst1_s16(dst->data(row + 0, col + 0),
               vget_low_s16(vreinterpretq_s16_s32(b[0].val[0])));
      vst1_s16(dst->data(row + 0, col + 4),
               vget_low_s16(vreinterpretq_s16_s32(b[2].val[0])));
      vst1_s16(dst->data(row + 1, col + 0),
               vget_low_s16(vreinterpretq_s16_s32(b[1].val[0])));
      vst1_s16(dst->data(row + 1, col + 4),
               vget_low_s16(vreinterpretq_s16_s32(b[3].val[0])));
      vst1_s16(dst->data(row + 2, col + 0),
               vget_low_s16(vreinterpretq_s16_s32(b[0].val[1])));
      vst1_s16(dst->data(row + 2, col + 4),
               vget_low_s16(vreinterpretq_s16_s32(b[2].val[1])));
      vst1_s16(dst->data(row + 3, col + 0),
               vget_low_s16(vreinterpretq_s16_s32(b[1].val[1])));
      vst1_s16(dst->data(row + 3, col + 4),
               vget_low_s16(vreinterpretq_s16_s32(b[3].val[1])));
      vst1_s16(dst->data(row + 4, col + 0),
               vget_high_s16(vreinterpretq_s16_s32(b[0].val[0])));
      vst1_s16(dst->data(row + 4, col + 4),
               vget_high_s16(vreinterpretq_s16_s32(b[2].val[0])));
      vst1_s16(dst->data(row + 5, col + 0),
               vget_high_s16(vreinterpretq_s16_s32(b[1].val[0])));
      vst1_s16(dst->data(row + 5, col + 4),
               vget_high_s16(vreinterpretq_s16_s32(b[3].val[0])));
      vst1_s16(dst->data(row + 6, col + 0),
               vget_high_s16(vreinterpretq_s16_s32(b[0].val[1])));
      vst1_s16(dst->data(row + 6, col + 4),
               vget_high_s16(vreinterpretq_s16_s32(b[2].val[1])));
      vst1_s16(dst->data(row + 7, col + 0),
               vget_high_s16(vreinterpretq_s16_s32(b[1].val[1])));
      vst1_s16(dst->data(row + 7, col + 4),
               vget_high_s16(vreinterpretq_s16_s32(b[3].val[1])));
    }
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_