1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // pack_SSE.h: optimized SSE specializations of the templates in pack.h. 16 17 #ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_ 18 #define GEMMLOWP_INTERNAL_PACK_SSE_H_ 19 20 #include <smmintrin.h> 21 #include "pack.h" 22 23 namespace gemmlowp { 24 25 // TODO: Add DepthMajorUint8SideMap 26 27 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor> 28 WidthMajorUint8SideMap; 29 30 template <int Cells> 31 using WidthMajorSideFormatNCells4x2 = 32 KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>; 33 34 template <int Cells> 35 class PackingRegisterBlock< 36 WidthMajorUint8SideMap, 37 PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > 38 : public PackingRegisterBlockBase< 39 WidthMajorUint8SideMap, 40 PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > { 41 public: 42 typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; 43 typedef typename KernelSideFormat::Cell CellFormat; 44 static constexpr int kCells = KernelSideFormat::kCells; 45 static constexpr int kCellWidth = CellFormat::kWidth; 46 static constexpr int kKernelWidth = CellFormat::kWidth * kCells; 47 static constexpr int kCellDepth = CellFormat::kDepth; 48 static constexpr int kCellSize = CellFormat::kSize; 49 Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)50 void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { 51 std::uint8_t* dst_ptr = dst->current_data(); 52 const int width_stride = this->complete_src_.width_stride(); 53 int depth_step = 8; 54 55 __m128i one = _mm_set1_epi16(1); 56 for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; 57 cell_start_depth += depth_step) { 58 for (int cell_start_width = 0; cell_start_width < kKernelWidth; 59 cell_start_width += kCellWidth) { 60 std::int32_t* cell_sums_of_each_slice_ptr = 61 dst->sums_of_each_slice() + start_width + cell_start_width; 62 const std::uint8_t* src_data = 63 this->complete_src_.data(cell_start_width, cell_start_depth); 64 65 __m128i xmm1 = 66 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0])); 67 __m128i xmm2 = _mm_loadl_epi64( 68 reinterpret_cast<const __m128i*>(&src_data[1 * width_stride])); 69 __m128i xmm3 = _mm_loadl_epi64( 70 reinterpret_cast<const __m128i*>(&src_data[2 * width_stride])); 71 __m128i xmm4 = _mm_loadl_epi64( 72 reinterpret_cast<const __m128i*>(&src_data[3 * width_stride])); 73 74 __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2); 75 __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31); 76 77 __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4); 78 __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80); 79 80 __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc); 81 __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc); 82 83 _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9); 84 _mm_storel_epi64( 85 reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10); 86 87 __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee); 88 __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee); 89 90 _mm_storel_epi64( 91 reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]), 92 xmm11); 93 _mm_storel_epi64( 94 reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]), 95 xmm12); 96 97 xmm1 = _mm_cvtepu8_epi16(xmm9); 98 xmm2 = _mm_madd_epi16(xmm1, one); 99 __m128i sums_of_each_slice_xmm = _mm_loadu_si128( 100 reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0])); 101 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); 102 103 xmm1 = _mm_cvtepu8_epi16(xmm10); 104 xmm2 = _mm_madd_epi16(xmm1, one); 105 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); 106 107 xmm1 = _mm_cvtepu8_epi16(xmm11); 108 xmm2 = _mm_madd_epi16(xmm1, one); 109 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); 110 111 xmm1 = _mm_cvtepu8_epi16(xmm12); 112 xmm2 = _mm_madd_epi16(xmm1, one); 113 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); 114 115 _mm_storeu_si128( 116 reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]), 117 sums_of_each_slice_xmm); 118 dst_ptr += kCellSize; 119 } 120 dst_ptr += 3 * kCellSize * kCells; 121 } 122 dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); 123 } 124 }; 125 126 } // namespace gemmlowp 127 128 #endif // GEMMLOWP_INTERNAL_PACK_SSE_H_ 129