/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"

#ifdef __SSSE3__

#include <emmintrin.h>  // SSE2
#include <tmmintrin.h>  // SSSE3
#ifdef __SSE4_1__
#include <smmintrin.h>  // SSE4.1
#endif

#include <cstdint>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {
namespace tensor_utils {
namespace {

// Dot product of four int8 vectors of 4 elements packed into an XMM register.
// Result is four int32 scalars packed into an XMM register.
// int8x4x4 · int8x4x4 => int32x4
static inline __m128i DotProdInt8x4x4(__m128i a_8x16, __m128i b_8x16) {
  // Transfer sign from 'a' to 'b', as _mm_maddubs_epi16 treats 'a' unsigned.
  b_8x16 = _mm_sign_epi8(b_8x16, a_8x16);
  a_8x16 = _mm_abs_epi8(a_8x16);
  // sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..7)
  __m128i sumprod_16x8 = _mm_maddubs_epi16(a_8x16, b_8x16);
  // sumprod[i] = sumprod[2*i]*1 + sumprod[2*i+1]*1 (i = 0..3)
  return _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
}
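
// For reference, a scalar sketch of what DotProdInt8x4x4 above computes,
// where a and b each hold 16 int8 values treated as four groups of four
// (ignoring the rare saturation corner cases of _mm_maddubs_epi16):
//
//   int32_t out[4];
//   for (int i = 0; i < 4; ++i) {
//     out[i] = 0;
//     for (int j = 0; j < 4; ++j) {
//       out[i] += int32_t{a[4 * i + j]} * int32_t{b[4 * i + j]};
//     }
//   }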

// Horizontally add 4 int32 values stored in a single XMM register to int32_t.
static inline int32_t ReduceInt32x4(__m128i acc) {
  // Shuffle to contain high half of acc (both in high and low halves).
  __m128i shuffle = _mm_unpackhi_epi64(acc, acc);
  // Add shuffle and acc; low half is sums of twos (high half is ignored).
  acc = _mm_add_epi32(acc, shuffle);
  // Shuffle the two elements in low half (ignore high half).
  shuffle = _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1));
  // Add shuffle and acc; lowest element is sum of all 4 inputs.
  acc = _mm_add_epi32(acc, shuffle);
  // Return lowest element as int32_t.
  return _mm_cvtsi128_si32(acc);
}
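
// Note: ReduceInt32x4 above is equivalent to returning
// acc[0] + acc[1] + acc[2] + acc[3].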

// Horizontally add each of 4 XMM registers with 4 int32 values, pack result
// into a single XMM register. Similar to ReduceInt32x4, but with 4x inputs.
static inline __m128i ReduceInt32x4x4(__m128i a, __m128i b, __m128i c,
                                      __m128i d) {
  // Assuming x = [x0, x1, x2, x3]
  const __m128i a_b_lo_half = _mm_unpacklo_epi32(a, b);  // [a0, b0, a1, b1]
  const __m128i a_b_hi_half = _mm_unpackhi_epi32(a, b);  // [a2, b2, a3, b3]
  const __m128i a_plus_b =
      _mm_add_epi32(a_b_lo_half, a_b_hi_half);  // [a0+a2, b0+b2, a1+a3, b1+b3]
  const __m128i c_d_lo_half = _mm_unpacklo_epi32(c, d);  // [c0, d0, c1, d1]
  const __m128i c_d_hi_half = _mm_unpackhi_epi32(c, d);  // [c2, d2, c3, d3]
  const __m128i c_plus_d =
      _mm_add_epi32(c_d_lo_half, c_d_hi_half);  // [c0+c2, d0+d2, c1+c3, d1+d3]
  const __m128i all_evns =
      _mm_unpacklo_epi64(a_plus_b, c_plus_d);  // [a02, b02, c02, d02]
  const __m128i all_odds =
      _mm_unpackhi_epi64(a_plus_b, c_plus_d);  // [a13, b13, c13, d13]
  return _mm_add_epi32(all_evns, all_odds);    // [a0123, b0123, c0123, d0123]
}

// Returns the ith element of an XMM register holding float numbers.
template <int i>
float GetFloatVectorElement(__m128 v) {
  static_assert(i >= 0 && i < 4, "The index must be 0 <= i < 4.");
  // Note, _mm_extract_ps returns int, so we can't use it here.
  // These lines will be optimized to extractps anyway.
  v = _mm_shuffle_ps(v, v, _MM_SHUFFLE(i, i, i, i));
  return _mm_cvtss_f32(v);
}

}  // namespace

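// Shared implementation of the hybrid (int8 weights, int8 inputs, float
// output) matrix * batch-vector product. For every batch and matrix row it
// computes the int8 dot product, optionally subtracts
// input_offset[batch] * row_sums[row] (the correction for a non-zero input
// zero point), scales the sum by scaling_factors[batch] (and
// per_channel_scale[row] when provided), and accumulates into 'result'.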
void SseMatrixBatchVectorMultiplyAccumulateImpl(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, const int32_t* row_sums) {
  for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
    const float batch_scaling_factor = scaling_factors[batch];
    const int32_t batch_offset = input_offset ? input_offset[batch] : 0;
    // Compute the dot product of the batch vector with every matrix row.
    for (std::intptr_t row = 0; row < m_rows; ++row) {
      // Get the address of the first element of the row.
      const int8_t* __restrict__ row_ptr = matrix + row * m_cols;
      const float row_scale =
          per_channel_scale ? per_channel_scale[row] * batch_scaling_factor
                            : batch_scaling_factor;
      const int32_t row_offset =
          row_sums && batch_offset ? batch_offset * row_sums[row] : 0;
      // Initialize the dot product sum for the row to 0.
      __m128i dotprod_32x4 = _mm_setzero_si128();
      std::intptr_t col = 0;
      // For every block of 16x 8-bit inputs.
      while (col < (m_cols & ~15)) {
        const __m128i vec_8x16 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
        const __m128i row_8x16 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
        col += 16;
      }
#ifdef __SSE4_1__
      // Postamble for 8x 8-bit inputs.
      if (col < (m_cols & ~7)) {
        const __m128i vec_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_madd_epi16(vec_16x8, row_16x8));
        col += 8;
      }
      // Postamble for 4x 8-bit inputs.
      if (col < (m_cols & ~3)) {
        const __m128i vec_32x4 = _mm_cvtepi8_epi32(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_32x4 = _mm_cvtepi8_epi32(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_mullo_epi32(vec_32x4, row_32x4));
        col += 4;
      }
#endif

      // Horizontally add the 4 intermediate sum values to get the final
      // dot-prod value for this row.
      int32_t sum = ReduceInt32x4(dotprod_32x4);

#if defined(__SSE4_1__) && defined(__clang__)
      // SSE 4.1: Don't try to unroll and vectorize this, already done above.
#pragma clang loop unroll(disable) vectorize(disable)
#endif
      // Postamble loop for <4x (<16x without SSE 4.1) remaining 8-bit inputs.
      for (; col < m_cols; ++col) {
        sum += row_ptr[col] * vectors[col];
      }  // for col
      if (row_offset) {
        sum -= row_offset;
      }
      *result += sum * row_scale;
      ++result;
    }  // for row

    vectors += m_cols;
  }  // for batch
}

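// Hybrid GEMM path: multiplies the int8 weights by a batch of int8 input
// vectors via cpu_backend_gemm, writing int32 accumulators to 'scratch'.
// The caller is responsible for converting the accumulators back to float
// with the per-batch scaling factors.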
void SseCpuBackendGemm(const int8_t* input, const int32_t* bias,
                       const int8_t* input_to_gate_weights, int32_t n_batch,
                       int32_t n_input, int32_t n_output, int32_t output_zp,
                       int32_t* scratch, CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = n_output;
  lhs_params.cols = n_input;
  lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = n_input;
  rhs_params.cols = n_batch;

  MatrixParams<int32_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = n_output;
  dst_params.cols = n_batch;

  GemmParams<int32, int32> gemm_params;
  if (bias) {
    gemm_params.bias = bias;
  }
  cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input,
                         dst_params, scratch, gemm_params, context);
}

void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result) {
  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
      /*row_sums=*/nullptr);
}

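// Overload with an int32 scratch buffer. When m_rows is a multiple of 4 the
// product is computed through SseCpuBackendGemm and then rescaled to float;
// otherwise it falls back to the plain SSE implementation above.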
void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch, int32_t* scratch,
    float* __restrict__ result, CpuBackendContext* context) {
  if (m_rows % 4 == 0) {
    const int32_t* bias = static_cast<const int32_t*>(nullptr);
    SseCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
                      /*output_zp=*/0, scratch, context);

    {
      ruy::profiler::ScopeLabel label("HybridMultiplyScalingFactor");
      // Multiply by float scaling factors and write to result
      const int total_size = n_batch * m_rows;
      int i = 0;
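      // 'i' walks the flat [n_batch x m_rows] accumulator array in 'scratch';
      // i / m_rows recovers the batch index for the per-batch scaling factor.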
      for (; i <= total_size - 8; i += 8, result += 8) {
        const float batch_scaling_factor0 = scaling_factors[i / m_rows];
        const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
        const __m128 scaling_factor0 = _mm_set1_ps(batch_scaling_factor0);
        const __m128 scaling_factor1 = _mm_set1_ps(batch_scaling_factor1);
        const __m128i scratch_val0 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(scratch + i));
        const __m128i scratch_val1 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(scratch + i + 4));
        const __m128 float_val0 = _mm_cvtepi32_ps(scratch_val0);
        const __m128 float_val1 = _mm_cvtepi32_ps(scratch_val1);
        const __m128 prod0 = _mm_mul_ps(float_val0, scaling_factor0);
        // Accumulate into the existing result values; use unaligned loads and
        // stores since 'result' is not guaranteed to be 16-byte aligned.
        const __m128 result0 = _mm_add_ps(_mm_loadu_ps(result), prod0);
        const __m128 prod1 = _mm_mul_ps(float_val1, scaling_factor1);
        const __m128 result1 = _mm_add_ps(_mm_loadu_ps(result + 4), prod1);
        _mm_storeu_ps(result, result0);
        _mm_storeu_ps(result + 4, result1);
      }
      scratch += i;
      for (; i < total_size; i++) {
        const float batch_scaling_factor = scaling_factors[i / m_rows];
        int32_t x = *(scratch++);
        *result += x * batch_scaling_factor;
        ++result;
      }
    }
    return;
  }


  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
      /*row_sums=*/nullptr);
}

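// Overload for asymmetric quantization: 'input_offset' holds the per-batch
// input zero points and 'row_sums' caches the per-row weight sums used for
// the offset correction. The row sums are refreshed when *compute_row_sums is
// set (or the flag pointer is null), after which the flag is cleared.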
void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* context) {
  if ((input_offset != nullptr) && (!compute_row_sums || *compute_row_sums)) {
    SseReductionSumVector(matrix, row_sums, m_rows, m_cols);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }
  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      per_channel_scale, input_offset, row_sums);
}

namespace {

// Implements sparse-matrix - vector multiply-accumulate.
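// The 'ledger' encodes the sparsity structure: for each row it stores the
// number of non-zero 16-column blocks, followed by the column-block index of
// each of those blocks.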
inline void SseSparseMatrixVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols, const int8_t* __restrict__ vector,
    const float scaling_factor, float* __restrict__ result) {
  static const std::intptr_t kBlockSize = 16;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
  const uint8_t* __restrict__ ledger_ptr = ledger;
  for (std::intptr_t row = 0; row < m_rows; ++row) {
    // Initialize the dot product sum for the row to 0.
    __m128i dotprod_32x4 = _mm_setzero_si128();
    std::intptr_t num_nonzero_blocks = *ledger_ptr++;
    for (std::intptr_t i = 0; i < num_nonzero_blocks; i++) {
      const std::intptr_t col_index = *ledger_ptr++ * kBlockSize;
      const __m128i vec_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(vector + col_index));
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(matrix));
      // dotprod += vec · row
      dotprod_32x4 =
          _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
      matrix += kBlockSize;
    }  // for col
    // Horizontally add the 4 intermediate sum values to get the final
    // dot-prod value for this row.
    int32_t dotprod = ReduceInt32x4(dotprod_32x4);

    result[row] += dotprod * scaling_factor;
  }  // for row
}

// Implements sparse-matrix - batch-of-4-vectors multiply-accumulate.
// The four input vectors must be stored consecutively with stride m_cols, and
// the four result vectors consecutively with stride m_rows.
inline void SseSparseMatrix4VectorsMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols,
    const int8_t* __restrict__ const vectors, const __m128 scaling_factors_fx4,
    float* __restrict__ const results) {
  static const std::intptr_t kBlockSize = 16;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);

  const int8_t* __restrict__ vector0 = vectors + 0 * m_cols;
  const int8_t* __restrict__ vector1 = vectors + 1 * m_cols;
  const int8_t* __restrict__ vector2 = vectors + 2 * m_cols;
  const int8_t* __restrict__ vector3 = vectors + 3 * m_cols;
  float* __restrict__ result0 = results + 0 * m_rows;
  float* __restrict__ result1 = results + 1 * m_rows;
  float* __restrict__ result2 = results + 2 * m_rows;
  float* __restrict__ result3 = results + 3 * m_rows;

  for (std::intptr_t row = 0; row < m_rows; ++row) {
    // Initialize the dot product sum for the row to 0.
    __m128i dp0_32x4 = _mm_setzero_si128();
    __m128i dp1_32x4 = _mm_setzero_si128();
    __m128i dp2_32x4 = _mm_setzero_si128();
    __m128i dp3_32x4 = _mm_setzero_si128();

    std::intptr_t num_nonzero_blocks = *ledger++;
    for (std::intptr_t i = 0; i < num_nonzero_blocks; i++) {
      const std::intptr_t col_index = *ledger++ * kBlockSize;
      // vecN are for different batches
      const __m128i vec0_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector0 + col_index));
      const __m128i vec1_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector1 + col_index));
      const __m128i vec2_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector2 + col_index));
      const __m128i vec3_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector3 + col_index));
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(matrix));
      // dp += vec · row
      // dpN are for different batches
      dp0_32x4 = _mm_add_epi32(dp0_32x4, DotProdInt8x4x4(row_8x16, vec0_8x16));
      dp1_32x4 = _mm_add_epi32(dp1_32x4, DotProdInt8x4x4(row_8x16, vec1_8x16));
      dp2_32x4 = _mm_add_epi32(dp2_32x4, DotProdInt8x4x4(row_8x16, vec2_8x16));
      dp3_32x4 = _mm_add_epi32(dp3_32x4, DotProdInt8x4x4(row_8x16, vec3_8x16));
      matrix += kBlockSize;
    }  // for col

    // Horizontally add the 4 intermediate values.
    const __m128i dp_32x4 =
        ReduceInt32x4x4(dp0_32x4, dp1_32x4, dp2_32x4, dp3_32x4);
    // Convert to float
    const __m128 dp_fx4 = _mm_cvtepi32_ps(dp_32x4);
    // Load the results (This is an Accumulate function..)
    __m128 result_fx4 =
        _mm_set_ps(result3[row], result2[row], result1[row], result0[row]);
    // result += dp .* scaling
    result_fx4 =
        _mm_add_ps(result_fx4, _mm_mul_ps(dp_fx4, scaling_factors_fx4));
    // Save the results
    result0[row] = GetFloatVectorElement<0>(result_fx4);
    result1[row] = GetFloatVectorElement<1>(result_fx4);
    result2[row] = GetFloatVectorElement<2>(result_fx4);
    result3[row] = GetFloatVectorElement<3>(result_fx4);
  }  // for row
}

}  // namespace

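// Processes batches in groups of 4 with the 4-vector kernel above, then
// handles any remaining batches one at a time.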
void SseSparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols, const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ results) {
  int batch = 0;
  const int kBatchSize4 = 4;
  const int n_batch_rounddown_to_batchsize_4 = n_batch & ~(kBatchSize4 - 1);
  while (batch < n_batch_rounddown_to_batchsize_4) {
    const __m128 scaling_factors_fx4 = _mm_loadu_ps(scaling_factors + batch);
    SseSparseMatrix4VectorsMultiplyAccumulate(
        matrix, ledger, m_rows, m_cols, vectors, scaling_factors_fx4, results);
    batch += kBatchSize4;
    vectors += kBatchSize4 * m_cols;
    results += kBatchSize4 * m_rows;
  }  // for batch
  while (batch < n_batch) {
    SseSparseMatrixVectorMultiplyAccumulate(matrix, ledger, m_rows, m_cols,
                                            vectors, scaling_factors[batch],
                                            results);
    ++batch;
    vectors += m_cols;
    results += m_rows;
  }  // for batch
}

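// Computes output_vector[row] = sum of the int8 inputs in each row of length
// reduction_size, for each of the output_size rows (used to build the row
// sums needed for the input-offset correction).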
void SseReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
                           const int output_size, const int reduction_size) {
  static constexpr std::intptr_t kBlockSize = 16;
  for (std::intptr_t row = 0; row < output_size; ++row) {
    const int8_t* __restrict__ row_ptr = input_vector + row * reduction_size;
    __m128i row_sum_16x8 = _mm_setzero_si128();
    std::intptr_t col = 0;
    for (; col < (reduction_size & ~(kBlockSize - 1)); col += kBlockSize) {
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
      const __m128i row_16x8 = _mm_maddubs_epi16(_mm_set1_epi8(1), row_8x16);
      row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
    }  // for col
#ifdef __SSE4_1__
    // Postamble for 8x 8-bit inputs.
    if (col < (reduction_size & ~7)) {
      // _mm_loadu_si64 not supported in gcc versions < 9, breaks kokoro build.
      const __m128i row_16x8 = _mm_cvtepi8_epi16(
          _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
      // row_sum += row
      row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
      col += 8;
    }
#endif
    const __m128i row_sum_32x4 =
        _mm_madd_epi16(row_sum_16x8, _mm_set1_epi16(1));
    int32_t row_sum = ReduceInt32x4(row_sum_32x4);
#if defined(__SSE4_1__) && defined(__clang__)
    // SSE 4.1: Don't try to unroll and vectorize this, already done above.
#pragma clang loop unroll(disable) vectorize(disable)
#endif
    for (; col < reduction_size; col++) {
      row_sum += row_ptr[col];
    }
    output_vector[row] = row_sum;
  }
}

}  // namespace tensor_utils
}  // namespace tflite

#endif  // __SSSE3__