1 // Copyright 2015 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code 16 17 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 18 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 19 20 #include "simd_wrappers.h" 21 22 namespace gemmlowp { 23 24 template <typename SrcScalarType, int N> 25 struct LoadImpl<RegBlockInt32<4, N>, 26 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 27 static RegBlockInt32<4, N> Run( 28 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 29 int col) { 30 RegBlockInt32<4, N> result; 31 for (int i = 0; i < N; i++) { 32 result.buf.reg[i] = LoadInt32x4(src.data(row, col + i)); 33 } 34 return result; 35 } 36 }; 37 38 template <typename SrcScalarType, int N> 39 struct LoadImpl<RegBlockInt32<8, N>, 40 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 41 static RegBlockInt32<8, N> Run( 42 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 43 int col) { 44 RegBlockInt32<8, N> result; 45 for (int i = 0; i < N; i++) { 46 result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i)); 47 result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i)); 48 } 49 return result; 50 } 51 }; 52 53 template <typename SrcScalarType> 54 struct LoadImpl<RegBlockInt32<1, 4>, 55 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 56 static RegBlockInt32<1, 4> Run( 57 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 58 int col) { 59 RegBlockInt32<1, 4> result; 60 std::int32_t buf[4]; 61 for (int i = 0; i < 4; i++) { 62 buf[i] = src(row, col + i); 63 } 64 result.buf.reg[0] = LoadInt32x4(buf); 65 return result; 66 } 67 }; 68 69 template <typename SrcScalarType> 70 struct LoadImpl<RegBlockInt32<1, 8>, 71 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 72 static RegBlockInt32<1, 8> Run( 73 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 74 int col) { 75 RegBlockInt32<1, 8> result; 76 std::int32_t buf[8]; 77 for (int i = 0; i < 8; i++) { 78 buf[i] = src(row, col + i); 79 } 80 result.buf.reg[0] = LoadInt32x4(buf); 81 result.buf.reg[1] = LoadInt32x4(buf + 4); 82 return result; 83 } 84 }; 85 86 template <typename SrcScalarType> 87 struct LoadImpl<RegBlockInt32<4, 1>, 88 VectorMap<SrcScalarType, VectorShape::Col>> { 89 static RegBlockInt32<4, 1> Run( 90 const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) { 91 RegBlockInt32<4, 1> result; 92 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 93 return result; 94 } 95 }; 96 97 template <typename SrcScalarType> 98 struct LoadImpl<RegBlockInt32<4, 1>, 99 VectorDup<SrcScalarType, VectorShape::Col>> { 100 static RegBlockInt32<4, 1> Run( 101 const VectorDup<SrcScalarType, VectorShape::Col>& src, int) { 102 RegBlockInt32<4, 1> result; 103 result.buf.reg[0] = LoadInt32x4(src(0)); 104 return result; 105 } 106 }; 107 108 template <typename SrcScalarType, int N> 109 struct LoadForBroadcastingImpl<RegBlockInt32<4, N>, 110 VectorMap<SrcScalarType, VectorShape::Col>> { 111 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; 112 using RegisterBlockType = RegBlockInt32<4, N>; 113 using ResultBlockType = 114 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 115 SrcObjectType>::Type; 116 117 static ResultBlockType Run(const SrcObjectType& src, int pos) { 118 ResultBlockType result; 119 static_assert(ResultBlockType::kRegisterCount == 1, ""); 120 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 121 return result; 122 } 123 }; 124 125 template <typename SrcScalarType, int N> 126 struct LoadForBroadcastingImpl<RegBlockInt32<8, N>, 127 VectorMap<SrcScalarType, VectorShape::Col>> { 128 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; 129 using RegisterBlockType = RegBlockInt32<8, N>; 130 using ResultBlockType = 131 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 132 SrcObjectType>::Type; 133 134 static ResultBlockType Run(const SrcObjectType& src, int pos) { 135 ResultBlockType result; 136 static_assert(ResultBlockType::kRegisterCount == 2, ""); 137 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 138 result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); 139 return result; 140 } 141 }; 142 143 template <typename SrcScalarType> 144 struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>, 145 VectorMap<SrcScalarType, VectorShape::Row>> { 146 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 147 using RegisterBlockType = RegBlockInt32<4, 1>; 148 using ResultBlockType = 149 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 150 SrcObjectType>::Type; 151 152 static ResultBlockType Run(const SrcObjectType& src, int pos) { 153 ResultBlockType result; 154 result.buf.reg[0] = src(pos); 155 return result; 156 } 157 }; 158 159 template <typename SrcScalarType, int N> 160 struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>, 161 VectorMap<SrcScalarType, VectorShape::Row>> { 162 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 163 using RegisterBlockType = RegBlockInt32<N, 4>; 164 using ResultBlockType = 165 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 166 SrcObjectType>::Type; 167 168 static ResultBlockType Run(const SrcObjectType& src, int pos) { 169 ResultBlockType result; 170 static_assert(ResultBlockType::kRegisterCount == 1, ""); 171 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 172 return result; 173 } 174 }; 175 176 template <typename SrcScalarType, int N> 177 struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>, 178 VectorMap<SrcScalarType, VectorShape::Row>> { 179 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 180 using RegisterBlockType = RegBlockInt32<N, 8>; 181 using ResultBlockType = 182 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 183 SrcObjectType>::Type; 184 185 static ResultBlockType Run(const SrcObjectType& src, int pos) { 186 ResultBlockType result; 187 static_assert(ResultBlockType::kRegisterCount == 2, ""); 188 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 189 result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); 190 return result; 191 } 192 }; 193 194 // 4x1 := 4x1 + 1x1 195 template <> 196 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { 197 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 198 const RegBlockInt32<1, 1>& rhs) { 199 RegBlockInt32<4, 1> result; 200 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 201 return result; 202 } 203 }; 204 205 // 1x4 := 1x4 + 1x1 206 template <> 207 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { 208 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 209 const RegBlockInt32<1, 1>& rhs) { 210 RegBlockInt32<1, 4> result; 211 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 212 return result; 213 } 214 }; 215 216 // 4x1 := 4x1 + 4x1 217 template <> 218 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { 219 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 220 const RegBlockInt32<4, 1>& rhs) { 221 RegBlockInt32<4, 1> result; 222 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 223 return result; 224 } 225 }; 226 227 // 1x4 := 1x4 + 1x4 228 template <> 229 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { 230 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 231 const RegBlockInt32<1, 4>& rhs) { 232 RegBlockInt32<1, 4> result; 233 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 234 return result; 235 } 236 }; 237 238 // 4x4 := 4x4 + 1x4 239 template <> 240 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { 241 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 242 const RegBlockInt32<1, 4>& rhs) { 243 RegBlockInt32<4, 4> result; 244 result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); 245 result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); 246 result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); 247 result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); 248 return result; 249 } 250 }; 251 252 // 4x4 := 4x4 + 4x1 253 template <> 254 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { 255 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 256 const RegBlockInt32<4, 1>& rhs) { 257 RegBlockInt32<4, 4> result; 258 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 259 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]); 260 result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); 261 result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]); 262 return result; 263 } 264 }; 265 266 // 8x1 := 8x1 + 1x1 267 template <> 268 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { 269 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 270 const RegBlockInt32<1, 1>& rhs) { 271 RegBlockInt32<8, 1> result; 272 const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); 273 for (int i = 0; i < 2; i++) { 274 result.buf.reg[i] = Add(lhs.buf.reg[i], p); 275 } 276 return result; 277 } 278 }; 279 280 // 8x1 := 8x1 + 8x1 281 template <> 282 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { 283 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 284 const RegBlockInt32<8, 1>& rhs) { 285 RegBlockInt32<8, 1> result; 286 for (int i = 0; i < 2; i++) { 287 result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); 288 } 289 return result; 290 } 291 }; 292 293 // 8x4 := 8x4 + 1x4 294 template <> 295 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { 296 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 297 const RegBlockInt32<1, 4>& rhs) { 298 RegBlockInt32<8, 4> result; 299 result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); 300 result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); 301 result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); 302 result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); 303 result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); 304 result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); 305 result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); 306 result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); 307 return result; 308 } 309 }; 310 311 // 8x4 := 8x4 + 8x1 312 template <> 313 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { 314 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 315 const RegBlockInt32<8, 1>& rhs) { 316 RegBlockInt32<8, 4> result; 317 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 318 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); 319 result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); 320 result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]); 321 result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]); 322 result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]); 323 result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]); 324 result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]); 325 return result; 326 } 327 }; 328 329 // 1x8 := 1x8 + 1x8 330 template <> 331 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> { 332 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, 333 const RegBlockInt32<1, 8>& rhs) { 334 RegBlockInt32<1, 8> result; 335 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 336 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); 337 return result; 338 } 339 }; 340 341 // 1x8 := 1x8 + 1x1 342 template <> 343 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> { 344 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, 345 const RegBlockInt32<1, 1>& rhs) { 346 RegBlockInt32<1, 8> result; 347 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 348 result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); 349 return result; 350 } 351 }; 352 353 // 4x1 := 4x1 * 1x1 354 template <> 355 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { 356 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 357 const RegBlockInt32<1, 1>& rhs) { 358 RegBlockInt32<4, 1> result; 359 result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 360 return result; 361 } 362 }; 363 364 // 4x1 := 4x1 * 4x1 365 template <> 366 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { 367 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 368 const RegBlockInt32<4, 1>& rhs) { 369 RegBlockInt32<4, 1> result; 370 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 371 return result; 372 } 373 }; 374 375 // 1x4 := 1x4 * 1x4 376 template <> 377 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { 378 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 379 const RegBlockInt32<1, 4>& rhs) { 380 RegBlockInt32<1, 4> result; 381 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 382 return result; 383 } 384 }; 385 386 // 1x4 := 1x4 * 1x1 387 template <> 388 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { 389 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 390 const RegBlockInt32<1, 1>& rhs) { 391 RegBlockInt32<1, 4> result; 392 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 393 return result; 394 } 395 }; 396 397 // 4x4 := 4x4 * 1x4 398 template <> 399 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { 400 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 401 const RegBlockInt32<1, 4>& rhs) { 402 RegBlockInt32<4, 4> result; 403 const Int32x4 p = rhs.buf.reg[0]; 404 result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p); 405 result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p); 406 result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p); 407 result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p); 408 return result; 409 } 410 }; 411 412 // 4x4 := 4x4 * 4x1 413 template <> 414 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { 415 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 416 const RegBlockInt32<4, 1>& rhs) { 417 RegBlockInt32<4, 4> result; 418 const Int32x4 p = rhs.buf.reg[0]; 419 result.buf.reg[0] = Mul(lhs.buf.reg[0], p); 420 result.buf.reg[1] = Mul(lhs.buf.reg[1], p); 421 result.buf.reg[2] = Mul(lhs.buf.reg[2], p); 422 result.buf.reg[3] = Mul(lhs.buf.reg[3], p); 423 return result; 424 } 425 }; 426 427 // 8x1 := 8x1 * 1x1 428 template <> 429 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { 430 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 431 const RegBlockInt32<1, 1>& rhs) { 432 RegBlockInt32<8, 1> result; 433 const std::int32_t p = rhs.buf.reg[0]; 434 for (int i = 0; i < 2; i++) { 435 result.buf.reg[i] = Mul(lhs.buf.reg[i], p); 436 } 437 return result; 438 } 439 }; 440 441 // 8x1 := 8x1 * 8x1 442 template <> 443 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { 444 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 445 const RegBlockInt32<8, 1>& rhs) { 446 RegBlockInt32<8, 1> result; 447 for (int i = 0; i < 2; i++) { 448 result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]); 449 } 450 return result; 451 } 452 }; 453 454 // 8x4 := 8x4 * 1x4 455 template <> 456 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { 457 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 458 const RegBlockInt32<1, 4>& rhs) { 459 RegBlockInt32<8, 4> result; 460 const Int32x4 p = rhs.buf.reg[0]; 461 for (int i = 0; i < 2; i++) { 462 result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p); 463 result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p); 464 result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p); 465 result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p); 466 } 467 return result; 468 } 469 }; 470 471 // 8x4 := 8x4 * 8x1 472 template <> 473 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { 474 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 475 const RegBlockInt32<8, 1>& rhs) { 476 RegBlockInt32<8, 4> result; 477 const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]}; 478 for (int i = 0; i < 4; i++) { 479 for (int j = 0; j < 2; j++) { 480 const int k = j + 2 * i; 481 result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]); 482 } 483 } 484 return result; 485 } 486 }; 487 488 // Rx1 += Rx1 * 1x1 489 template <int Rows> 490 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, 491 RegBlockInt32<Rows, 1>> { 492 static void Run(const RegBlockInt32<Rows, 1>& lhs, 493 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) { 494 const std::int32_t p = rhs.buf.reg[0]; 495 for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) { 496 MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); 497 } 498 } 499 }; 500 501 // RxC += Rx1 * 1x1 502 template <int Rows, int Cols> 503 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, 504 RegBlockInt32<Rows, Cols>> { 505 static void Run(const RegBlockInt32<Rows, 1>& lhs, 506 const RegBlockInt32<1, 1>& rhs, 507 RegBlockInt32<Rows, Cols>* acc) { 508 const std::int32_t p = rhs.buf.reg[0]; 509 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 510 for (int i = 0; i < kRegsPerCol; i++) { 511 const Int32x4 q = Mul(lhs.buf.reg[i], p); 512 for (int j = 0; j < Cols; j++) { 513 acc->buf.reg[i + j * kRegsPerCol] = 514 Add(acc->buf.reg[i + j * kRegsPerCol], q); 515 } 516 } 517 } 518 }; 519 520 // 1xC += 1xC * 1x1 521 template <int Cols> 522 struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>, 523 RegBlockInt32<1, Cols>> { 524 static void Run(const RegBlockInt32<1, Cols>& lhs, 525 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { 526 const std::int32_t p = rhs.buf.reg[0]; 527 for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { 528 MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); 529 } 530 } 531 }; 532 533 // RxC += 1x1 * 1x1 534 template <int Rows, int Cols> 535 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 536 RegBlockInt32<Rows, Cols>> { 537 static void Run(const RegBlockInt32<1, 1>& lhs, 538 const RegBlockInt32<1, 1>& rhs, 539 RegBlockInt32<Rows, Cols>* acc) { 540 const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); 541 for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) { 542 acc->buf.reg[i] = Add(acc->buf.reg[i], p); 543 } 544 } 545 }; 546 547 // 1x1 += 1x1 * 1x1 548 template <> 549 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 550 RegBlockInt32<1, 1>> { 551 static void Run(const RegBlockInt32<1, 1>& lhs, 552 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) { 553 MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]); 554 } 555 }; 556 557 // Rx4 += Rx1 * 1x4 558 template <int Rows> 559 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>, 560 RegBlockInt32<Rows, 4>> { 561 static void Run(const RegBlockInt32<Rows, 1>& lhs, 562 const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) { 563 const Int32x4 p = rhs.buf.reg[0]; 564 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 565 for (int i = 0; i < kRegsPerCol; i++) { 566 MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]); 567 MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]); 568 MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]); 569 MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]); 570 } 571 } 572 }; 573 574 // Rx4 += 1x4 * 1x1 575 template <int Rows> 576 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, 577 RegBlockInt32<Rows, 4>> { 578 static void Run(const RegBlockInt32<1, 4>& lhs, 579 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) { 580 const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 581 Int32x4 q[4]; 582 q[0] = DupLane<0>(p); 583 q[1] = DupLane<1>(p); 584 q[2] = DupLane<2>(p); 585 q[3] = DupLane<3>(p); 586 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 587 for (int i = 0; i < kRegsPerCol; i++) { 588 for (int j = 0; j < 4; j++) { 589 acc->buf.reg[i + j * kRegsPerCol] = 590 Add(q[j], acc->buf.reg[i + j * kRegsPerCol]); 591 } 592 } 593 } 594 }; 595 596 // 1xC += 1x1 * 1x1 597 template <int Cols> 598 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 599 RegBlockInt32<1, Cols>> { 600 static void Run(const RegBlockInt32<1, 1>& lhs, 601 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { 602 const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); 603 for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { 604 acc->buf.reg[i] = Add(acc->buf.reg[i], p); 605 } 606 } 607 }; 608 609 // 1x4 += 1x4 * 1x1 610 template <> 611 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, 612 RegBlockInt32<1, 4>> { 613 static void Run(const RegBlockInt32<1, 4>& lhs, 614 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) { 615 const std::int32_t p = rhs.buf.reg[0]; 616 MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); 617 } 618 }; 619 620 // 4xC += 4x1 * 1x1 621 template <int Cols> 622 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, 623 RegBlockInt32<4, Cols>> { 624 static void Run(const RegBlockInt32<4, 1>& lhs, 625 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) { 626 const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 627 for (int i = 0; i < Cols; i++) { 628 acc->buf.reg[i] = Add(p, acc->buf.reg[i]); 629 } 630 } 631 }; 632 633 // 4x1 += 4x1 * 1x1 634 template <> 635 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, 636 RegBlockInt32<4, 1>> { 637 static void Run(const RegBlockInt32<4, 1>& lhs, 638 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) { 639 const std::int32_t p = rhs.buf.reg[0]; 640 MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); 641 } 642 }; 643 644 } // namespace gemmlowp 645 646 #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 647