1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // pack_avx.h: optimized AVX specializations of the templates in pack.h. 16 17 #ifndef GEMMLOWP_INTERNAL_PACK_AVX_H_ 18 #define GEMMLOWP_INTERNAL_PACK_AVX_H_ 19 20 #include <immintrin.h> 21 #include "pack.h" 22 23 namespace gemmlowp { 24 25 // TODO: Add DepthMajorUint8SideMap 26 27 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor> 28 WidthMajorUint8SideMap; 29 30 template <int Cells> 31 using WidthMajorSideFormatNCells4x2 = 32 KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, Cells>; 33 34 template <int Cells> 35 class PackingRegisterBlock< 36 WidthMajorUint8SideMap, 37 PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> 38 : public PackingRegisterBlockBase< 39 WidthMajorUint8SideMap, 40 PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> { 41 public: 42 typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; 43 typedef typename KernelSideFormat::Cell CellFormat; 44 static const int kCells = KernelSideFormat::kCells; 45 static const int kCellWidth = CellFormat::kWidth; 46 static const int kKernelWidth = CellFormat::kWidth * kCells; 47 static const int kCellDepth = CellFormat::kDepth; 48 static const int kCellSize = CellFormat::kSize; 49 Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)50 void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) { 51 std::uint8_t *dst_ptr = dst->current_data(); 52 const int width_stride = this->complete_src_.width_stride(); 53 int depth_step = 16; 54 55 __m256i one = _mm256_set1_epi16(1); 56 for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; 57 cell_start_depth += depth_step) { 58 for (int cell_start_width = 0; cell_start_width < kKernelWidth; 59 cell_start_width += kCellWidth) { 60 std::int32_t *cell_sums_of_each_slice_ptr = 61 dst->sums_of_each_slice() + start_width + cell_start_width; 62 const std::uint8_t *src_data = 63 this->complete_src_.data(cell_start_width, cell_start_depth); 64 65 __m128i xmm1 = 66 _mm_loadu_si128(reinterpret_cast<const __m128i *>(&src_data[0])); 67 __m128i xmm2 = _mm_loadu_si128( 68 reinterpret_cast<const __m128i *>(&src_data[1 * width_stride])); 69 __m128i xmm3 = _mm_loadu_si128( 70 reinterpret_cast<const __m128i *>(&src_data[2 * width_stride])); 71 __m128i xmm4 = _mm_loadu_si128( 72 reinterpret_cast<const __m128i *>(&src_data[3 * width_stride])); 73 __m128i xmm5 = _mm_loadu_si128( 74 reinterpret_cast<const __m128i *>(&src_data[4 * width_stride])); 75 __m128i xmm6 = _mm_loadu_si128( 76 reinterpret_cast<const __m128i *>(&src_data[5 * width_stride])); 77 __m128i xmm7 = _mm_loadu_si128( 78 reinterpret_cast<const __m128i *>(&src_data[6 * width_stride])); 79 __m128i xmm8 = _mm_loadu_si128( 80 reinterpret_cast<const __m128i *>(&src_data[7 * width_stride])); 81 82 __m256i ymm1 = _mm256_set_m128i(xmm5, xmm1); 83 __m256i ymm2 = _mm256_set_m128i(xmm6, xmm2); 84 __m256i ymm3 = _mm256_set_m128i(xmm7, xmm3); 85 __m256i ymm4 = _mm256_set_m128i(xmm8, xmm4); 86 87 __m256i ymm5 = _mm256_unpacklo_epi16(ymm1, ymm2); 88 __m256i ymm6 = _mm256_unpacklo_epi16(ymm3, ymm4); 89 90 __m256i ymm9 = _mm256_unpackhi_epi16(ymm1, ymm2); 91 __m256i ymm10 = _mm256_unpackhi_epi16(ymm3, ymm4); 92 93 __m256i ymm7 = _mm256_unpacklo_epi32(ymm5, ymm6); 94 __m256i ymm8 = _mm256_unpackhi_epi32(ymm5, ymm6); 95 96 __m256i ymm13 = _mm256_unpacklo_epi32(ymm9, ymm10); 97 __m256i ymm14 = _mm256_unpackhi_epi32(ymm9, ymm10); 98 99 __m256i ymm11 = _mm256_permute4x64_epi64(ymm7, 0xd8); 100 __m256i ymm12 = _mm256_permute4x64_epi64(ymm8, 0xd8); 101 102 __m256i ymm15 = _mm256_permute4x64_epi64(ymm13, 0xd8); 103 __m256i ymm16 = _mm256_permute4x64_epi64(ymm14, 0xd8); 104 105 __m128i xmm9 = _mm256_castsi256_si128(ymm11); 106 __m128i xmm10 = _mm256_castsi256_si128(ymm12); 107 __m128i xmm11 = _mm256_extracti128_si256(ymm11, 1); 108 __m128i xmm12 = _mm256_extracti128_si256(ymm12, 1); 109 110 xmm1 = _mm256_castsi256_si128(ymm15); 111 xmm2 = _mm256_castsi256_si128(ymm16); 112 xmm3 = _mm256_extracti128_si256(ymm15, 1); 113 xmm4 = _mm256_extracti128_si256(ymm16, 1); 114 115 _mm_storeu_si128(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9); 116 _mm_storeu_si128( 117 reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm11); 118 _mm_storeu_si128( 119 reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]), 120 xmm10); 121 _mm_storeu_si128( 122 reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]), 123 xmm12); 124 _mm_storeu_si128( 125 reinterpret_cast<__m128i *>(&dst_ptr[4 * kCellSize * kCells]), 126 xmm1); 127 _mm_storeu_si128( 128 reinterpret_cast<__m128i *>(&dst_ptr[5 * kCellSize * kCells]), 129 xmm3); 130 131 _mm_storeu_si128( 132 reinterpret_cast<__m128i *>(&dst_ptr[6 * kCellSize * kCells]), 133 xmm2); 134 _mm_storeu_si128( 135 reinterpret_cast<__m128i *>(&dst_ptr[7 * kCellSize * kCells]), 136 xmm4); 137 138 ymm6 = _mm256_cvtepu8_epi16(xmm9); 139 ymm7 = _mm256_madd_epi16(ymm6, one); 140 __m256i sums_of_each_slice_xmm = _mm256_loadu_si256( 141 reinterpret_cast<const __m256i *>(&cell_sums_of_each_slice_ptr[0])); 142 sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); 143 144 ymm6 = _mm256_cvtepu8_epi16(xmm11); 145 ymm7 = _mm256_madd_epi16(ymm6, one); 146 sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); 147 148 ymm6 = _mm256_cvtepu8_epi16(xmm10); 149 ymm7 = _mm256_madd_epi16(ymm6, one); 150 sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); 151 152 ymm6 = _mm256_cvtepu8_epi16(xmm12); 153 ymm7 = _mm256_madd_epi16(ymm6, one); 154 sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); 155 156 ymm6 = _mm256_cvtepu8_epi16(xmm1); 157 ymm7 = _mm256_madd_epi16(ymm6, one); 158 sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); 159 160 ymm6 = _mm256_cvtepu8_epi16(xmm3); 161 ymm7 = _mm256_madd_epi16(ymm6, one); 162 sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); 163 164 ymm6 = _mm256_cvtepu8_epi16(xmm2); 165 ymm7 = _mm256_madd_epi16(ymm6, one); 166 sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); 167 168 ymm6 = _mm256_cvtepu8_epi16(xmm4); 169 ymm7 = _mm256_madd_epi16(ymm6, one); 170 sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); 171 172 _mm256_storeu_si256( 173 reinterpret_cast<__m256i *>(&cell_sums_of_each_slice_ptr[0]), 174 sums_of_each_slice_xmm); 175 dst_ptr += kCellSize; 176 } 177 dst_ptr += 7 * kCellSize * kCells; 178 } 179 dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); 180 } 181 }; 182 183 // Pack format for 4x2 rhs format 184 template <int Cells> 185 using RhsWidthMajorSideFormatNCells4x2 = 186 KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>; 187 188 template <int Cells> 189 class PackingRegisterBlock< 190 WidthMajorUint8SideMap, 191 PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> 192 : public PackingRegisterBlockBase< 193 WidthMajorUint8SideMap, 194 PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> { 195 public: 196 typedef RhsWidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; 197 typedef typename KernelSideFormat::Cell CellFormat; 198 static const int kCells = KernelSideFormat::kCells; 199 static const int kCellWidth = CellFormat::kWidth; 200 static const int kKernelWidth = CellFormat::kWidth * kCells; 201 static const int kCellDepth = CellFormat::kDepth; 202 static const int kCellSize = CellFormat::kSize; 203 Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)204 void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) { 205 std::uint8_t *dst_ptr = dst->current_data(); 206 const int width_stride = this->complete_src_.width_stride(); 207 int depth_step = 8; 208 209 __m128i one = _mm_set1_epi16(1); 210 for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; 211 cell_start_depth += depth_step) { 212 for (int cell_start_width = 0; cell_start_width < kKernelWidth; 213 cell_start_width += kCellWidth) { 214 std::int32_t *cell_sums_of_each_slice_ptr = 215 dst->sums_of_each_slice() + start_width + cell_start_width; 216 const std::uint8_t *src_data = 217 this->complete_src_.data(cell_start_width, cell_start_depth); 218 219 __m128i xmm1 = 220 _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&src_data[0])); 221 __m128i xmm2 = _mm_loadl_epi64( 222 reinterpret_cast<const __m128i *>(&src_data[1 * width_stride])); 223 __m128i xmm3 = _mm_loadl_epi64( 224 reinterpret_cast<const __m128i *>(&src_data[2 * width_stride])); 225 __m128i xmm4 = _mm_loadl_epi64( 226 reinterpret_cast<const __m128i *>(&src_data[3 * width_stride])); 227 228 __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2); 229 __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31); 230 231 __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4); 232 __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80); 233 234 __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc); 235 __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc); 236 237 _mm_storel_epi64(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9); 238 _mm_storel_epi64( 239 reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm10); 240 241 __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee); 242 __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee); 243 244 _mm_storel_epi64( 245 reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]), 246 xmm11); 247 _mm_storel_epi64( 248 reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]), 249 xmm12); 250 251 xmm1 = _mm_cvtepu8_epi16(xmm9); 252 xmm2 = _mm_madd_epi16(xmm1, one); 253 __m128i sums_of_each_slice_xmm = _mm_loadu_si128( 254 reinterpret_cast<const __m128i *>(&cell_sums_of_each_slice_ptr[0])); 255 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); 256 257 xmm1 = _mm_cvtepu8_epi16(xmm10); 258 xmm2 = _mm_madd_epi16(xmm1, one); 259 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); 260 261 xmm1 = _mm_cvtepu8_epi16(xmm11); 262 xmm2 = _mm_madd_epi16(xmm1, one); 263 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); 264 265 xmm1 = _mm_cvtepu8_epi16(xmm12); 266 xmm2 = _mm_madd_epi16(xmm1, one); 267 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); 268 269 _mm_storeu_si128( 270 reinterpret_cast<__m128i *>(&cell_sums_of_each_slice_ptr[0]), 271 sums_of_each_slice_xmm); 272 dst_ptr += kCellSize; 273 } 274 dst_ptr += 3 * kCellSize * kCells; 275 } 276 dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); 277 } 278 }; 279 280 } // namespace gemmlowp 281 282 #endif // GEMMLOWP_INTERNAL_PACK_AVX_H_ 283