// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// pack_msa.h: optimized MSA specializations of the templates in pack.h.

#ifndef GEMMLOWP_INTERNAL_PACK_MSA_H_
#define GEMMLOWP_INTERNAL_PACK_MSA_H_

#include "pack.h"

#include <msa.h>

namespace gemmlowp {

typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
    WidthMajorUint8SideMap;

template <int Cells>
using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;

template <int Cells>
class PackingRegisterBlock<
    WidthMajorUint8SideMap,
    PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
 public:
  typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static constexpr int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* const src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data
    v16i8 src_lines[4 * kCells];
    for (int i = 0; i < 4 * kCells; i++) {
      src_lines[i] = __builtin_msa_ld_b(
          const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
    }
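    // Note on the interleaves used below (assumed MSA semantics, with lane
    // order written for a little-endian layout):
    //   __builtin_msa_ilvr_b(a, b) -> { b[0], a[0], b[1], a[1], ..., b[7], a[7] }
    //   __builtin_msa_ilvl_b(a, b) -> { b[8], a[8], b[9], a[9], ..., b[15], a[15] }
    // ILVR interleaves the low (right) halves of its operands, ILVL the high
    // (left) halves. Interleaving lines 0/2 and 1/3 first, then interleaving
    // those two results, leaves every aligned group of 4 bytes holding
    // { line0[k], line1[k], line2[k], line3[k] }, i.e. one depth slice of a
    // DepthMajor 4-wide cell.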
    // Reorder the data within registers to make DepthMajor 4x2 cells
    v16i8 src_lines_intertwined_2x[2 * kCells][2];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_2x[2 * i][0] =
          __builtin_msa_ilvr_b(src_lines[4 * i + 2], src_lines[4 * i]);
      src_lines_intertwined_2x[2 * i][1] =
          __builtin_msa_ilvl_b(src_lines[4 * i + 2], src_lines[4 * i]);
      src_lines_intertwined_2x[2 * i + 1][0] =
          __builtin_msa_ilvr_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
      src_lines_intertwined_2x[2 * i + 1][1] =
          __builtin_msa_ilvl_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
    }
    v16i8 src_lines_intertwined_4x[2 * kCells][2];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_4x[2 * i][0] =
          __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][0],
                               src_lines_intertwined_2x[2 * i][0]);
      src_lines_intertwined_4x[2 * i][1] =
          __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][0],
                               src_lines_intertwined_2x[2 * i][0]);
      src_lines_intertwined_4x[2 * i + 1][0] =
          __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][1],
                               src_lines_intertwined_2x[2 * i][1]);
      src_lines_intertwined_4x[2 * i + 1][1] =
          __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][1],
                               src_lines_intertwined_2x[2 * i][1]);
    }
    // Store the resulting DepthMajor 4x2 cells in the destination packed block
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        if (kCells % 2 == 0) {
          for (int cell = 0; cell < kCells; cell += 2) {
            v2i64 tmp = __builtin_msa_ilvr_d(
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * cell + outer][inner]));
            __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
            dst_ptr += 16;
          }
          for (int cell = 0; cell < kCells; cell += 2) {
            v2i64 tmp = __builtin_msa_ilvl_d(
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * cell + outer][inner]));
            __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
            dst_ptr += 16;
          }
        } else {
          // Store even number of low vector halves.
          for (int cell = 0; cell < kCells - 1; cell += 2) {
            v2i64 tmp = __builtin_msa_ilvr_d(
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * cell + outer][inner]));
            __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
            dst_ptr += 16;
          }
          // Store last low half and first high half.
          v2i64 tmp = reinterpret_cast<v2i64>(
              src_lines_intertwined_4x[2 * 0 + outer][inner]);
          tmp = __builtin_msa_insve_d(
              tmp, 0,
              reinterpret_cast<v2i64>(
                  src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
          __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
          dst_ptr += 16;
          // Store even number of high vector halves.
          for (int cell = 1; cell < kCells; cell += 2) {
            v2i64 tmp = __builtin_msa_ilvl_d(
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * cell + outer][inner]));
            __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
            dst_ptr += 16;
          }
        }
      }
    }
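    // The depth sums below widen by interleaving with a zero vector: with the
    // assumed ILVR/ILVL semantics, ilvr_b(zeroes, x) puts a zero byte in each
    // odd lane, so on a little-endian target every u8 element of x becomes
    // the matching u16 element of the result. The same idiom widens the
    // 16-bit partial sums to 32 bits one level further down.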
    // Compute sums across the depth dimension
    v8i16 sums_of_2_cells[kCells][4];
    const v16i8 zeroes = __builtin_msa_ldi_b(0);
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        int i = 2 * outer + inner;
        for (int cell = 0; cell < kCells; cell++) {
          v8i16 tmp0 = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(
              zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
          v8i16 tmp1 = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(
              zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
          sums_of_2_cells[cell][i] = __builtin_msa_addv_h(tmp0, tmp1);
        }
      }
    }
    v4i32 sums_of_4_cells[kCells][4];
    for (int i = 0; i < 4; i++) {
      for (int cell = 0; cell < kCells; cell++) {
        v4i32 tmp0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(
            reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
        v4i32 tmp1 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(
            reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
        sums_of_4_cells[cell][i] = __builtin_msa_addv_w(tmp0, tmp1);
      }
    }
    // Update the sums_of_each_slice vector
    for (int cell = 0; cell < kCells; cell++) {
      v4i32 s01 = __builtin_msa_addv_w(sums_of_4_cells[cell][0],
                                       sums_of_4_cells[cell][1]);
      v4i32 s23 = __builtin_msa_addv_w(sums_of_4_cells[cell][2],
                                       sums_of_4_cells[cell][3]);
      v4i32 s = __builtin_msa_addv_w(s01, s23);
      std::int32_t* sums_of_each_slice_ptr =
          dst->sums_of_each_slice() + start_width + 4 * cell;
      v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
      tmp = __builtin_msa_addv_w(tmp, s);
      __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

template <int Cells>
using WidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;

template <int Cells>
class PackingRegisterBlock<
    WidthMajorUint8SideMap,
    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
 public:
  typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static constexpr int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data
    v8i16 src_lines[kCells * 4];
    for (int i = 0; i < kCells; i++) {
#define GEMMLOWP_UNROLLED_LOOP_ITER(k)                             \
  src_lines[4 * i + k] =                                           \
      __builtin_msa_ld_h(const_cast<std::uint8_t*>(src_ptr), 0);   \
  src_ptr += stride;

      GEMMLOWP_UNROLLED_LOOP_ITER(0)
      GEMMLOWP_UNROLLED_LOOP_ITER(1)
      GEMMLOWP_UNROLLED_LOOP_ITER(2)
      GEMMLOWP_UNROLLED_LOOP_ITER(3)

#undef GEMMLOWP_UNROLLED_LOOP_ITER
    }
    // Reorder the data within registers to make WidthMajor 4x2 cells
    v8i16 src_lines_intertwined_2x[2 * kCells][2];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_2x[2 * i][0] =
          __builtin_msa_ilvr_h(src_lines[4 * i + 2], src_lines[4 * i]);
      src_lines_intertwined_2x[2 * i][1] =
          __builtin_msa_ilvl_h(src_lines[4 * i + 2], src_lines[4 * i]);
      src_lines_intertwined_2x[2 * i + 1][0] =
          __builtin_msa_ilvr_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
      src_lines_intertwined_2x[2 * i + 1][1] =
          __builtin_msa_ilvl_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
    }
    v8i16 src_lines_intertwined_4x[2 * kCells][2];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_4x[2 * i][0] =
          __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][0],
                               src_lines_intertwined_2x[2 * i][0]);
      src_lines_intertwined_4x[2 * i][1] =
          __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][0],
                               src_lines_intertwined_2x[2 * i][0]);
      src_lines_intertwined_4x[2 * i + 1][0] =
          __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][1],
                               src_lines_intertwined_2x[2 * i][1]);
      src_lines_intertwined_4x[2 * i + 1][1] =
          __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][1],
                               src_lines_intertwined_2x[2 * i][1]);
    }
    // Store the resulting WidthMajor 4x2 cells in the destination packed block
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        if (kCells % 2 == 0) {
          for (int cell = 0; cell < kCells; cell += 2) {
            v2i64 tmp = __builtin_msa_ilvr_d(
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * cell + outer][inner]));
            __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
            dst_ptr += 16;
          }
          for (int cell = 0; cell < kCells; cell += 2) {
            v2i64 tmp = __builtin_msa_ilvl_d(
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * cell + outer][inner]));
            __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
            dst_ptr += 16;
          }
        } else {
          // Store even number of low vector halves.
          for (int cell = 0; cell < kCells - 1; cell += 2) {
            v2i64 tmp = __builtin_msa_ilvr_d(
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * cell + outer][inner]));
            __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
            dst_ptr += 16;
          }
          // Store last low half and first high half.
          v2i64 tmp = reinterpret_cast<v2i64>(
              src_lines_intertwined_4x[2 * 0 + outer][inner]);
          tmp = __builtin_msa_insve_d(
              tmp, 0,
              reinterpret_cast<v2i64>(
                  src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
          __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
          dst_ptr += 16;
          // Store even number of high vector halves.
          for (int cell = 1; cell < kCells; cell += 2) {
            v2i64 tmp = __builtin_msa_ilvl_d(
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
                reinterpret_cast<v2i64>(
                    src_lines_intertwined_4x[2 * cell + outer][inner]));
            __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
            dst_ptr += 16;
          }
        }
      }
    }
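    // The reduction below uses the MSA horizontal add: hadd_u_h(x, x) widens
    // adjacent byte pairs and adds them, yielding x[2i] + x[2i+1] as an
    // unsigned 16-bit result in lane i (assumed semantics; operand order is
    // immaterial here since both operands are the same vector).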
    // Compute sums across the depth dimension
    v8i16 sums_of_2[kCells][4];
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        int i = 2 * outer + inner;
        for (int cell = 0; cell < kCells; cell++) {
          sums_of_2[cell][i] = reinterpret_cast<v8i16>(__builtin_msa_hadd_u_h(
              reinterpret_cast<v16u8>(
                  src_lines_intertwined_4x[2 * cell + outer][inner]),
              reinterpret_cast<v16u8>(
                  src_lines_intertwined_4x[2 * cell + outer][inner])));
        }
      }
    }
    v8i16 sums_of_4[kCells][2];
    for (int i = 0; i < 2; i++) {
      for (int cell = 0; cell < kCells; cell++) {
        sums_of_4[cell][i] = __builtin_msa_addv_h(sums_of_2[cell][2 * i],
                                                  sums_of_2[cell][2 * i + 1]);
      }
    }
    v8i16 sums_of_8[kCells];
    for (int cell = 0; cell < kCells; cell++) {
      sums_of_8[cell] =
          __builtin_msa_addv_h(sums_of_4[cell][0], sums_of_4[cell][1]);
    }

    v4i32 sums_of_16[kCells];
    const v8i16 zeroes = __builtin_msa_ldi_h(0);
    for (int cell = 0; cell < kCells; cell++) {
      sums_of_16[cell] = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(zeroes, sums_of_8[cell]));
      v8i16 tmp = __builtin_msa_ilvl_h(zeroes, sums_of_8[cell]);
      sums_of_16[cell] =
          __builtin_msa_addv_w(sums_of_16[cell], reinterpret_cast<v4i32>(tmp));
    }
    // Update the sums_of_each_slice vector
    for (int cell = 0; cell < kCells; cell++) {
      std::int32_t* sums_of_each_slice_ptr =
          dst->sums_of_each_slice() + start_width + 4 * cell;
      v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
      tmp = __builtin_msa_addv_w(tmp, sums_of_16[cell]);
      __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

template <int Width>
using Int8FastKernelFormat =
    KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;

template <int Width>
class PackingRegisterBlock<WidthMajorUint8SideMap,
                           PackedSideBlock<Int8FastKernelFormat<Width>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<Int8FastKernelFormat<Width>>> {
 public:
  static_assert(Width == 2 || Width == 4, "");
  typedef Int8FastKernelFormat<Width> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* const src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data.
    v16i8 src_lines[Width];
    for (int i = 0; i < Width; i++) {
      src_lines[i] = __builtin_msa_ld_b(
          const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
    }
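    // The loop below converts the loaded uint8 values to int8 by subtracting
    // 128 from every byte. Subtracting 128 (0x80) modulo 256 is the same as
    // flipping bit 7, so a single BNEGI.B does it: 0 -> -128, 128 -> 0,
    // 255 -> 127.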
    for (int i = 0; i < Width; i++) {
      // Subtract 128 by inverting bit 7.
      src_lines[i] = reinterpret_cast<v16i8>(
          __builtin_msa_bnegi_b(reinterpret_cast<v16u8>(src_lines[i]), 7));
    }
    for (int i = 0; i < Width; i++) {
      __builtin_msa_st_b(src_lines[i], dst_ptr + 16 * i, 0);
    }
    v8i16 sums2[Width];
    for (int i = 0; i < Width; i++) {
      sums2[i] = __builtin_msa_hadd_s_h(src_lines[i], src_lines[i]);
    }
    v4i32 sums4_wide[Width];
    for (int i = 0; i < Width; i++) {
      sums4_wide[i] = __builtin_msa_hadd_s_w(sums2[i], sums2[i]);
    }
    v8i16 sums4[Width / 2];
    for (int i = 0; i < Width / 2; i++) {
      sums4[i] = __builtin_msa_pckev_h(
          reinterpret_cast<v8i16>(sums4_wide[2 * i + 1]),
          reinterpret_cast<v8i16>(sums4_wide[2 * i]));
    }
    v4i32 sums8_wide[Width / 2];
    for (int i = 0; i < Width / 2; i++) {
      sums8_wide[i] = __builtin_msa_hadd_s_w(sums4[i], sums4[i]);
    }
    if (Width == 4) {
      v4i32 sum = __builtin_msa_ld_w(const_cast<std::int32_t*>(sums_ptr), 0);
      v8i16 sums8 = __builtin_msa_pckev_h(
          reinterpret_cast<v8i16>(sums8_wide[1]),
          reinterpret_cast<v8i16>(sums8_wide[0]));
      v4i32 sums16 = __builtin_msa_hadd_s_w(sums8, sums8);
      sum = __builtin_msa_addv_w(sum, sums16);
      __builtin_msa_st_w(sum, sums_ptr, 0);
    } else {
      assert(Width == 2);
      std::int32_t sum[2] = { sums_ptr[0], sums_ptr[1] };
      v2i64 sums16 = __builtin_msa_hadd_s_d(sums8_wide[0], sums8_wide[0]);
      sum[0] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 0);
      sum[1] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 2);
      sums_ptr[0] = sum[0];
      sums_ptr[1] = sum[1];
    }
    dst->seek_forward_n_cells(1);
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_PACK_MSA_H_