/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/tensor_utils.h"

#include <math.h>

#include <gmock/gmock.h>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/test_util.h"

#ifdef DOTPROD_BENCHMARKS
#include "testing/base/public/benchmark.h"
#endif  // DOTPROD_BENCHMARKS

namespace tflite {
namespace tensor_utils {

// Normally we would require bit-for-bit exact results. Unfortunately, a bug
// in the Intel arm_neon_sse.h translation header that we use for x86 tests
// causes 1-bit inaccuracy in the vqrdmulh_n_s32 intrinsic, which leads to
// off-by-one errors. So for now we tolerate a few off-by-one errors, while
// still ensuring that no more than a small minority of values are wrong.
// This utility compares rounding results for integer outputs.
template <typename T>
void CompareRoundingResults(int flat_size, const T* expected_result,
                            const T* real_result, int max_element_tolerance = 1,
                            int max_total_tolerance = 5) {
  int max_diff = 0;
  int64_t total_diff = 0;
  for (int i = 0; i < flat_size; i++) {
    int diff = static_cast<int>(std::abs(expected_result[i] - real_result[i]));
    total_diff += diff;
    max_diff = std::max(max_diff, diff);
  }

  EXPECT_LE(max_diff, max_element_tolerance);
  EXPECT_LE(total_diff, max_total_tolerance);
}
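
// A minimal usage sketch (values below are illustrative only): with the
// default tolerances, two off-by-one entries pass, while a single entry that
// is off by two would fail the max_element_tolerance check.
//   const int16_t expected[3] = {100, 200, 300};
//   const int16_t actual[3] = {101, 200, 299};
//   CompareRoundingResults(3, expected, actual);  // max diff 1, total diff 2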

TEST(uKernels, FloorLog2Test) {
  for (int i = 1; i < 257; ++i) {
    EXPECT_EQ(::tflite::FloorLog2(i),
              static_cast<int>(std::floor(std::log2(i))));
  }
}
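
// For reference: for positive integers, FloorLog2(x) is the index of the
// highest set bit, e.g. FloorLog2(1) == 0 and FloorLog2(8) == FloorLog2(9)
// == 3, which matches the floor(log2(x)) reference used above.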

TEST(uKernels, VectorScalarMultiply) {
  constexpr int kVectorSize = 29;
  static int8_t input[kVectorSize];
  for (int i = 0; i < kVectorSize; ++i) {
    input[i] = static_cast<int8_t>(i - 14);
  }
  const float scale = 0.1f;
  std::vector<float> output(kVectorSize, 0.0f);
  VectorScalarMultiply(input, kVectorSize, scale, output.data());
  EXPECT_THAT(output,
              ElementsAreArray(ArrayFloatNear(
                  {-1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5,
                   -0.4, -0.3, -0.2, -0.1, 0,    0.1,  0.2,  0.3,  0.4,  0.5,
                   0.6,  0.7,  0.8,  0.9,  1.0,  1.1,  1.2,  1.3,  1.4})));
}

#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))

// Tests whether a float array is full of zero values.
TEST(uKernels, IsZeroFloatTest) {
  // Single NEON vector (= 4 floats)
  {
    const float four_zeros[4] = {0, 0, 0, 0};
    EXPECT_TRUE(IsZeroVector(four_zeros, ARRAY_SIZE(four_zeros)));
  }
  {
    const float four_nonzeros[4] = {1, 2, 3, 4};
    EXPECT_FALSE(IsZeroVector(four_nonzeros, ARRAY_SIZE(four_nonzeros)));
  }
  // Multiple NEON vectors
  {
    const float eight_zeros[8] = {0, 0, 0, 0, 0, 0, 0, 0};
    EXPECT_TRUE(IsZeroVector(eight_zeros, ARRAY_SIZE(eight_zeros)));
  }
  {
    const float eight_nonzeros[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    EXPECT_FALSE(IsZeroVector(eight_nonzeros, ARRAY_SIZE(eight_nonzeros)));
  }
  {
    const float multiple_four_mixed1[8] = {0, 0, 0, 0, 5, 6, 7, 8};
    EXPECT_FALSE(
        IsZeroVector(multiple_four_mixed1, ARRAY_SIZE(multiple_four_mixed1)));
  }
  {
    const float multiple_four_mixed2[8] = {1, 2, 3, 4, 0, 0, 0, 0};
    EXPECT_FALSE(
        IsZeroVector(multiple_four_mixed2, ARRAY_SIZE(multiple_four_mixed2)));
  }
  // Less than one NEON vector
  {
    const float three_zeros[3] = {0, 0, 0};
    EXPECT_TRUE(IsZeroVector(three_zeros, ARRAY_SIZE(three_zeros)));
  }
  {
    const float three_nonzeros[3] = {1, 2, 3};
    EXPECT_FALSE(IsZeroVector(three_nonzeros, ARRAY_SIZE(three_nonzeros)));
  }
  {
    const float three_mixed[3] = {1, 0, 3};
    EXPECT_FALSE(IsZeroVector(three_mixed, ARRAY_SIZE(three_mixed)));
  }
  // Postamble after NEON vectors
  {
    const float seven_zeros[7] = {0, 0, 0, 0, 0, 0, 0};
    EXPECT_TRUE(IsZeroVector(seven_zeros, ARRAY_SIZE(seven_zeros)));
  }
  {
    const float seven_nonzeros[7] = {1, 2, 3, 4, 5, 6, 7};
    EXPECT_FALSE(IsZeroVector(seven_nonzeros, ARRAY_SIZE(seven_nonzeros)));
  }
  {
    const float nonzeros_after_zeros[7] = {0, 0, 0, 0, 5, 6, 7};
    EXPECT_FALSE(
        IsZeroVector(nonzeros_after_zeros, ARRAY_SIZE(nonzeros_after_zeros)));
  }
}

// Tests whether an int8 array is full of zero values.
TEST(uKernels, IsZeroInt8Test) {
  // Single NEON vector (= 16x int8_t)
  {
    const int8_t sixteen_zeros[16] = {0, 0, 0, 0, 0, 0, 0, 0,
                                      0, 0, 0, 0, 0, 0, 0, 0};
    EXPECT_TRUE(IsZeroVector(sixteen_zeros, ARRAY_SIZE(sixteen_zeros)));
  }
  {
    const int8_t sixteen_nonzeros[16] = {1, 2,  3,  4,  5,  6,  7,  8,
                                         9, 10, 11, 12, 13, 14, 15, 16};
    EXPECT_FALSE(IsZeroVector(sixteen_nonzeros, ARRAY_SIZE(sixteen_nonzeros)));
  }
  // Multiple NEON vectors
  {
    const int8_t thirtytwo_zeros[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    EXPECT_TRUE(IsZeroVector(thirtytwo_zeros, ARRAY_SIZE(thirtytwo_zeros)));
  }
  {
    const int8_t thirtytwo_nonzeros[32] = {
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    EXPECT_FALSE(
        IsZeroVector(thirtytwo_nonzeros, ARRAY_SIZE(thirtytwo_nonzeros)));
  }
  {
    const int8_t thirtytwo_mixed1[32] = {1,  2,  3,  4,  5,  6, 7, 8, 9, 10, 11,
                                         12, 13, 14, 15, 16, 0, 0, 0, 0, 0,  0,
                                         0,  0,  0,  0,  0,  0, 0, 0, 0, 0};
    EXPECT_FALSE(IsZeroVector(thirtytwo_mixed1, ARRAY_SIZE(thirtytwo_mixed1)));
  }
  {
    const int8_t thirtytwo_mixed2[32] = {0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,
                                         0, 0, 0, 0,  0,  1,  2,  3,  4,  5, 6,
                                         7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    EXPECT_FALSE(IsZeroVector(thirtytwo_mixed2, ARRAY_SIZE(thirtytwo_mixed2)));
  }
  // Less than one NEON vector
  {
    const int8_t fifteen_zeros[15] = {0, 0, 0, 0, 0, 0, 0, 0,
                                      0, 0, 0, 0, 0, 0, 0};
    EXPECT_TRUE(IsZeroVector(fifteen_zeros, ARRAY_SIZE(fifteen_zeros)));
  }
  {
    const int8_t fifteen_nonzeros[15] = {1, 2,  3,  4,  5,  6,  7, 8,
                                         9, 10, 11, 12, 13, 14, 15};
    EXPECT_FALSE(IsZeroVector(fifteen_nonzeros, ARRAY_SIZE(fifteen_nonzeros)));
  }
  {
    const int8_t fifteen_mixed[15] = {1, 0, 3,  0, 5,  0, 7, 0,
                                      9, 0, 11, 0, 13, 0, 15};
    EXPECT_FALSE(IsZeroVector(fifteen_mixed, ARRAY_SIZE(fifteen_mixed)));
  }
  // Postamble after NEON vectors
  {
    const int8_t seventeen_zeros[17] = {0, 0, 0, 0, 0, 0, 0, 0, 0,
                                        0, 0, 0, 0, 0, 0, 0, 0};
    EXPECT_TRUE(IsZeroVector(seventeen_zeros, ARRAY_SIZE(seventeen_zeros)));
  }
  {
    const int8_t seventeen_nonzeros[17] = {1,  2,  3,  4,  5,  6,  7,  8, 9,
                                           10, 11, 12, 13, 14, 15, 16, 17};
    EXPECT_FALSE(
        IsZeroVector(seventeen_nonzeros, ARRAY_SIZE(seventeen_nonzeros)));
  }
  {
    const int8_t nonzeros_after_zeros[17] = {0, 0, 0, 0, 0, 0, 0, 0,
                                             0, 0, 0, 0, 0, 0, 17};
    EXPECT_FALSE(
        IsZeroVector(nonzeros_after_zeros, ARRAY_SIZE(nonzeros_after_zeros)));
  }
}

#undef ARRAY_SIZE

TEST(uKernels, SymmetricQuantizeFloatsTest) {
  constexpr int kVectorSize = 9;
  static float input[kVectorSize] = {-640, -635.0, -630, 10.0,  2.0,
                                     -5.0, -10.0,  0.0,  1000.0};

  int8_t output[kVectorSize];
  float min, max, scaling_factor;
  SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
                          &scaling_factor);

  EXPECT_EQ(min, -640);
  EXPECT_EQ(max, 1000);
  // EXPECT_EQ won't work here due to floating-point rounding.
  EXPECT_NEAR(scaling_factor, 1000 / 127.0, 1e-6);
  EXPECT_THAT(output,
              testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127}));
}
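
// In the symmetric scheme the scaling factor is max(|min|, |max|) / 127 =
// 1000 / 127 ~= 7.874, and each value quantizes to round(value / scale);
// e.g. -640 / 7.874 ~= -81.3 -> -81 and 1000 / 7.874 = 127, matching the
// expected output above.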

TEST(uKernels, SymmetricQuantizeFloatsAllZerosTest) {
  constexpr int kVectorSize = 9;
  static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0};

  int8_t output[kVectorSize];
  float min, max, scaling_factor;
  SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
                          &scaling_factor);

  EXPECT_EQ(min, 0);
  EXPECT_EQ(max, 0);
  EXPECT_EQ(scaling_factor, 1);
  EXPECT_THAT(output, testing::ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0}));
}

TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) {
  constexpr int kVectorSize = 9;
  static float input[kVectorSize] = {-1e-5, 3e-5, -7e-6, -9e-5, 1e-6,
                                     4e-5,  9e-6, 2e-4,  0};

  int8_t output[kVectorSize];
  float min, max, scaling_factor;
  SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
                          &scaling_factor);

  EXPECT_NEAR(min, -9e-05, 1e-6);
  EXPECT_NEAR(max, 0.0002, 1e-6);
  EXPECT_NEAR(scaling_factor, 1.57e-6, 1e-6);
  EXPECT_THAT(output,
              testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0}));
}

TEST(uKernels, AsymmetricQuantizeFloatsTest) {
  constexpr int kVectorSize = 9;
  static float input[kVectorSize] = {-640, -635.0, -630, 10.0,  2.0,
                                     -5.0, -10.0,  0.0,  1000.0};
  int8_t output[kVectorSize];
  double min = -640.0;
  double max = 1000.0;
  QuantizationParams quantization_params =
      ChooseQuantizationParams<int8_t>(min, max);
  float scale = quantization_params.scale;
  int32_t offset = quantization_params.zero_point;
  float test_scale;
  int32_t test_offset;
  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
                           &test_offset);
  // EXPECT_EQ won't work here due to floating-point rounding.
  EXPECT_NEAR(test_scale, scale, 1e-6);
  EXPECT_EQ(test_offset, offset);
  EXPECT_THAT(output, testing::ElementsAreArray(
                          {-128, -127, -126, -26, -28, -29, -30, -28, 127}));
}
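
// In the asymmetric scheme, ChooseQuantizationParams picks
// scale = (max - min) / 255 = 1640 / 255 ~= 6.43 and zero_point = -28, and
// each value maps to round(value / scale) + zero_point; e.g. 0.0 -> -28 and
// -640 -> round(-99.5) - 28 = -128, matching the expected output above.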

TEST(uKernels, AsymmetricQuantizeFloatsAllZerosTest) {
  constexpr int kVectorSize = 9;
  static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
  int8_t output[kVectorSize];
  float test_scale;
  int32_t test_offset;
  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
                           &test_offset);
  EXPECT_EQ(test_scale, 1);
  EXPECT_EQ(test_offset, 0);
  EXPECT_THAT(output, testing::ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0}));
}

TEST(uKernels, AsymmetricQuantizeFloatsZeroRangeTest) {
  constexpr int kVectorSize = 9;
  static float input[kVectorSize] = {2000, 2000, 2000, 2000, 2000,
                                     2000, 2000, 2000, 2000};
  int8_t output[kVectorSize];
  double min = 0;
  double max = 2000;
  QuantizationParams quantization_params =
      ChooseQuantizationParams<int8_t>(min, max);
  int32_t offset = quantization_params.zero_point;
  float scale = quantization_params.scale;
  float test_scale;
  int32_t test_offset;
  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
                           &test_offset);
  EXPECT_NEAR(test_scale, scale, 1e-6);
  EXPECT_EQ(test_offset, offset);
  EXPECT_THAT(output, testing::ElementsAreArray(
                          {127, 127, 127, 127, 127, 127, 127, 127, 127}));
}

TEST(uKernels, AsymmetricQuantizeFloatsAllAlmostZeroTest) {
  constexpr int kVectorSize = 9;
  static float input[kVectorSize] = {-1e-5, 3e-5, -7e-6, -9e-5, 1e-6,
                                     4e-5,  9e-6, 2e-4,  0};
  int8_t output[kVectorSize];
  double min = -9e-05;
  double max = 0.0002;
  QuantizationParams quantization_params =
      ChooseQuantizationParams<int8_t>(min, max);
  int32_t offset = quantization_params.zero_point;
  float scale = quantization_params.scale;
  float test_scale;
  int32_t test_offset;
  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
                           &test_offset);
  EXPECT_NEAR(test_scale, scale, 1e-6);
  EXPECT_EQ(test_offset, offset);
  EXPECT_THAT(output, testing::ElementsAreArray(
                          {-58, -23, -55, -128, -48, -14, -41, 127, -49}));
}

TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) {
  constexpr int kRow = 3;
  constexpr int kCol = 4;
  constexpr int kBatch = 2;
  static float matrix[kRow * kCol] = {1.0,  2.0,  3.0,  4.0,   //
                                      -1.0, -2.0, -3.0, -4.0,  //
                                      1.0,  -2.0, 3.0,  -4.0};
  static float vector[kCol * kBatch] = {1.0, -1.0, 1.0, -1.0,  //
                                        2.0, -2.0, 2.0, -2.0};
  std::vector<float> output(kRow * kBatch);
  std::fill(output.begin(), output.end(), 3.0);
  MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch,
                                      output.data());
  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({1., 5., 13.,  //
                                                       -1., 7., 23.})));
}

// Quantized matmul with 2 * 30 input and 9 * 30 matrix.
TEST(uKernels, QuantMatrixBatchVectorMultiplyAccumulate8x8_16Test) {
  CpuBackendContext context;
  const std::vector<int8_t> input = {
      4,   -41, 5,   -41, 22,  17, -30, 24,  13,  -47, 18, 9,   -11, -30, 16,
      -47, 12,  36,  -20, 27,  -3, 0,   -51, -31, 3,   -8, -38, 43,  23,  12,
      11,  -23, -26, 23,  14,  -9, -44, 22,  21,  -30, 3,  -47, -26, -21, -24,
      -44, 34,  -11, -23, -28, 26, -38, 19,  35,  9,   23, 6,   -42, -25, 28,
  };
  const std::vector<int32_t> input_zeropoint_times_weights = {
      -620, -170, -395, 715, -1220, -1080, 1130, -260, -470,
  };
  const std::vector<int8_t> input_to_gate_weights = {
      -10, -4,  -8,  16,  4,   -16, -1,  11,  1,   2,   -25, 19,  7,   9,   2,
      -24, -2,  10,  -7,  7,   -5,  -2,  3,   4,   3,   -4,  -7,  -11, -13, -18,
      11,  10,  12,  -9,  17,  -15, -5,  20,  -6,  -11, 2,   -6,  -18, 15,  4,
      4,   -9,  -2,  -3,  -9,  -13, 17,  -21, 5,   3,   -12, 0,   -4,  9,   -5,
      10,  -2,  8,   1,   -10, -6,  1,   -9,  10,  11,  -1,  -5,  4,   -7,  -4,
      -4,  4,   12,  -7,  -5,  -9,  -19, 6,   -4,  12,  -17, -22, 0,   9,   -4,
      -5,  5,   -8,  8,   3,   15,  -18, -18, 5,   3,   -12, 5,   -10, 7,   7,
      -9,  17,  2,   -11, -25, 3,   19,  -6,  7,   1,   7,   5,   -3,  11,  3,
      0,   -8,  8,   -2,  -2,  -12, 14,  -5,  7,   8,   16,  20,  -16, -5,  -5,
      1,   -10, -6,  14,  10,  -12, 10,  -6,  5,   0,   3,   8,   -9,  -13, -2,
      4,   4,   -16, -17, -9,  16,  -5,  14,  -9,  -5,  -12, 0,   17,  6,   -1,
      16,  -20, 1,   -11, -1,  -10, -21, 13,  4,   -12, -7,  0,   -14, -6,  3,
      -4,  6,   -18, -3,  -1,  14,  -8,  -6,  -15, 5,   12,  -3,  -10, 4,   6,
      -5,  -20, 0,   3,   -3,  -7,  1,   2,   -10, 7,   -3,  6,   1,   -12, 6,
      4,   -12, 2,   6,   -20, 0,   5,   23,  15,  14,  9,   8,   20,  -2,  9,
      -8,  -8,  -7,  -4,  -8,  -9,  7,   -12, -2,  2,   1,   -14, 31,  4,   -14,
      3,   10,  -18, -17, -1,  18,  1,   12,  0,   7,   -3,  -5,  8,   -9,  18,
      17,  7,   -15, 3,   20,  4,   -8,  16,  6,   -3,  -3,  9,   -4,  -6,  4,
  };
  const int32_t multiplier = 2080364544;
  const int32_t shift = -2;

  std::vector<int32_t> scratch(2 * 9, 0);
  std::vector<int16_t> output = {10, 2, 33, 4, 5,  6,  65, 4,  3,
                                 52, 1, 2,  8, -1, -2, 11, 17, -18};
  MatrixBatchVectorMultiplyAccumulate(
      input.data(), input_zeropoint_times_weights.data(),
      input_to_gate_weights.data(), multiplier, shift,
      /*n_batch=*/2, /*n_input=*/30, /*n_output=*/9, /*output_zp=*/0,
      scratch.data(), output.data(), &context);
  const std::vector<int16_t> expected_output = {
      -210, 331,  153, 139, -570, -657, 258, 515,  -495,
      91,   -243, -73, 603, -744, -269, 169, -748, -174,
  };

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
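
// Note: the multiplier/shift pair above encodes the requantization scale in
// TFLite's usual fixed-point form; each int32 accumulator is scaled by
// approximately (multiplier / 2^31) * 2^shift, here 0.96875 * 2^-2 ~= 0.24,
// before saturating to the int16 output range.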

TEST(uKernels, HybridMatrixBatchVectorMultiplyAccumulate8x8_16Test) {
  CpuBackendContext context;
  const std::vector<int8_t> input = {
      4,   -41, 5,   -41, 22,  17,  -30, 24,  13,  -47, 18,  9,   -11, -30, 16,
      1,   -47, 12,  36,  -20, 27,  -3,  0,   -51, -31, 3,   -8,  -38, 43,  23,
      12,  1,   11,  -23, -26, 23,  14,  -9,  -44, 22,  21,  -30, 3,   -47, -26,
      -21, -24, 1,   -44, 34,  -11, -23, -28, 26,  -38, 19,  35,  9,   23,  6,
      -42, -25, 28,  1,   4,   -41, 5,   -41, 22,  17,  -30, 24,  13,  -47, 18,
      9,   -11, -30, 16,  1,   -47, 12,  36,  -20, 27,  -3,  0,   -51, -31, 3,
      -8,  -38, 43,  23,  12,  1,   11,  -23, -26, 23,  14,  -9,  -44, 22,  21,
      -30, 3,   -47, -26, -21, -24, 1,   -44, 34,  -11, -23, -28, 26,  -38, 19,
      35,  9,   23,  6,   -42, -25, 28,  1,
  };
  const std::vector<int32_t> input_offsets = {1, 1, 1, 1};

  const std::vector<float> scaling_factors = {
      1.0,
      1.0,
      1.0,
      1.0,
  };

  const std::vector<int8_t> input_to_gate_weights = {
      -10, -4,  -8,  16,  4,  -16, -1,  11,  1,   2,   -25, 19,  7,   9,   2,
      1,   -24, -2,  10,  -7, 7,   -5,  -2,  3,   4,   3,   -4,  -7,  -11, -13,
      -18, 2,   11,  10,  12, -9,  17,  -15, -5,  20,  -6,  -11, 2,   -6,  -18,
      15,  4,   3,   4,   -9, -2,  -3,  -9,  -13, 17,  -21, 5,   3,   -12, 0,
      -4,  9,   -5,  4,   10, -2,  8,   1,   -10, -6,  1,   -9,  10,  11,  -1,
      -5,  4,   -7,  -4,  5,  -4,  4,   12,  -7,  -5,  -9,  -19, 6,   -4,  12,
      -17, -22, 0,   9,   -4, 6,   -5,  5,   -8,  8,   3,   15,  -18, -18, 5,
      3,   -12, 5,   -10, 7,  7,   7,   -9,  17,  2,   -11, -25, 3,   19,  -6,
      7,   1,   7,   5,   -3, 11,  3,   8,   0,   -8,  8,   -2,  -2,  -12, 14,
      -5,  7,   8,   16,  20, -16, -5,  -5,  9,   1,   -10, -6,  14,  10,  -12,
      10,  -6,  5,   0,   3,  8,   -9,  -13, -2,  10,  4,   4,   -16, -17, -9,
      16,  -5,  14,  -9,  -5, -12, 0,   17,  6,   -1,  11,  16,  -20, 1,   -11,
      -1,  -10, -21, 13,  4,  -12, -7,  0,   -14, -6,  3,   12,  -4,  6,   -18,
      -3,  -1,  14,  -8,  -6, -15, 5,   12,  -3,  -10, 4,   6,   13,  -5,  -20,
      0,   3,   -3,  -7,  1,  2,   -10, 7,   -3,  6,   1,   -12, 6,   14,  -5,
      -20, 0,   3,   -3,  -7, 1,   2,   -10, 7,   -3,  6,   1,   -12, 6,   15,
      -5,  -20, 0,   3,   -3, -7,  1,   2,   -10, 7,   -3,  6,   1,   -12, 6,
      16,
  };

  std::vector<int32_t> scratch(5 * 8, 0);
  std::vector<float> output(4 * 8, 0);
  int32_t* row_sums = scratch.data() + 8 * 4;
  bool compute_row_sums = true;
  MatrixBatchVectorMultiplyAccumulate(
      input_to_gate_weights.data(), /*m_rows=*/8, /*m_cols=*/32, input.data(),
      scaling_factors.data(), /*n_batch=*/4, output.data(), nullptr,
      input_offsets.data(), scratch.data(), row_sums, &compute_row_sums,
      &context);

  const std::vector<float_t> expected_output = {
      -228, 1548,  937, -166, -1164, -1578, -278,  303, 839,  -820,  132,
      1733, -1858, 58,  -425, -587,  -228,  1548,  937, -166, -1164, -1578,
      -278, 303,   839, -820, 132,   1733,  -1858, 58,  -425, -587,
  };

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
  EXPECT_THAT(compute_row_sums, false);

  std::vector<float> output2(4 * 8, 0);
  MatrixBatchVectorMultiplyAccumulate(
      input_to_gate_weights.data(), /*m_rows=*/8, /*m_cols=*/32, input.data(),
      scaling_factors.data(), /*n_batch=*/4, output2.data(), nullptr,
      input_offsets.data(), scratch.data(), row_sums, &compute_row_sums,
      &context);

  EXPECT_THAT(output2, testing::ElementsAreArray(expected_output));

  // Run with a large batch size to trigger the CpuBackendGemm path on any
  // device.
  constexpr int kBatchMultiplier = 8;
  std::vector<int8_t> input_big_batch(input.size() * kBatchMultiplier);
  std::vector<float> scaling_factors_big_batch(scaling_factors.size() *
                                               kBatchMultiplier);
  std::vector<int32_t> scratch_big_batch(scratch.size() * kBatchMultiplier);
  std::vector<int32_t> input_offsets_big_batch(input_offsets.size() *
                                               kBatchMultiplier);
  for (int i = 0; i < kBatchMultiplier; i++) {
    std::copy(input.begin(), input.end(),
              input_big_batch.begin() + i * input.size());
    std::copy(scaling_factors.begin(), scaling_factors.end(),
              scaling_factors_big_batch.begin() + i * scaling_factors.size());
    std::copy(input_offsets.begin(), input_offsets.end(),
              input_offsets_big_batch.begin() + i * input_offsets.size());
  }
  std::vector<float> output_big_batch(output.size() * kBatchMultiplier, 0);
  MatrixBatchVectorMultiplyAccumulate(
      input_to_gate_weights.data(), /*m_rows=*/8, /*m_cols=*/32,
      input_big_batch.data(), scaling_factors_big_batch.data(),
      /*n_batch=*/4 * kBatchMultiplier, output_big_batch.data(), nullptr,
      input_offsets_big_batch.data(), scratch_big_batch.data(), row_sums,
      &compute_row_sums, &context);
  for (int i = 0; i < kBatchMultiplier; i++) {
    std::vector<float> output_per_batch(
        output_big_batch.begin() + i * output.size(),
        output_big_batch.begin() + (i + 1) * output.size());
    EXPECT_THAT(output_per_batch, testing::ElementsAreArray(expected_output));
  }
}

// Quantized matmul with 2 * 30 input and 9 * 30 matrix.
TEST(uKernels, QuantMatrixBatchVectorMultiplyAccumulate8x8_8Test) {
  CpuBackendContext context;
  const std::vector<int8_t> input = {
      4,   -41, 5,   -41, 22,  17, -30, 24,  13,  -47, 18, 9,   -11, -30, 16,
      -47, 12,  36,  -20, 27,  -3, 0,   -51, -31, 3,   -8, -38, 43,  23,  12,
      11,  -23, -26, 23,  14,  -9, -44, 22,  21,  -30, 3,  -47, -26, -21, -24,
      -44, 34,  -11, -23, -28, 26, -38, 19,  35,  9,   23, 6,   -42, -25, 28,
  };
  const std::vector<int32_t> input_zeropoint_times_weights = {
      0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  const std::vector<int8_t> input_to_gate_weights = {
      13,  -7,  -20, -22, 8,   -46, 9,   -2,  -18, -42, 40,  28,  -7,  24,  34,
      -7,  -24, -24, 19,  14,  -19, -6,  -2,  -3,  5,   -36, -13, 6,   -27, 36,
      -23, 0,   20,  -37, -23, 9,   17,  -41, 33,  -15, -18, -42, -41, -34, -16,
      -6,  12,  -14, -15, -20, -14, 21,  -3,  -1,  -26, 54,  51,  35,  -14, 9,
      -2,  13,  -6,  39,  34,  -21, 39,  -51, 19,  -44, 52,  0,   -2,  -38, -35,
      -33, 4,   -22, -37, 27,  -23, 3,   -10, 5,   32,  6,   1,   -35, 24,  -19,
      46,  43,  -55, 5,   38,  -14, 32,  -43, -44, -17, -13, -28, 56,  28,  -42,
      4,   10,  -7,  25,  -15, -9,  -25, -14, -15, 6,   -10, -22, 40,  -72, 18,
      -6,  -18, -2,  37,  -13, -10, 11,  -9,  32,  -28, 19,  -2,  4,   -31, 50,
      -15, 23,  -34, -9,  41,  -6,  -34, 17,  2,   24,  -15, 21,  -17, -8,  -20,
      1,   -63, 19,  -40, 12,  -5,  5,   -6,  1,   19,  -9,  -23, 5,   -34, 11,
      26,  21,  54,  34,  -43, -29, 1,   16,  31,  -56, -28, 57,  -15, -23, 37,
      -17, -3,  -6,  29,  18,  77,  17,  -20, -14, -19, 8,   -24, -7,  -45, -3,
      0,   -25, -8,  6,   9,   3,   -15, 51,  4,   -15, -19, -16, -14, -47, -52,
      25,  9,   58,  26,  -9,  -27, 49,  -6,  -21, 21,  18,  12,  -9,  -9,  14,
      31,  -26, -19, -50, 17,  35,  11,  -10, 22,  -16, -43, -2,  26,  55,  -20,
      -7,  21,  33,  -20, 26,  -15, -22, 30,  27,  3,   -34, 26,  12,  -1,  19,
      26,  -25, 10,  30,  30,  -14, -23, -23, -35, -16, 26,  -41, 11,  1,   21,
  };
  const int32_t multiplier = 1347771520;
  const int32_t shift = -7;
  const int32_t output_zp = -11;

  std::vector<int8_t> output = {1, 2, 3, 4, 5,  6,  5,  4,  3,
                                2, 1, 2, 8, -1, -2, 11, 17, 18};
  std::vector<int32_t> scratch(2 * 9, 0);
  MatrixBatchVectorMultiplyAccumulate(
      input.data(), input_zeropoint_times_weights.data(),
      input_to_gate_weights.data(), multiplier, shift,
      /*n_batch=*/2, /*n_input=*/30, /*n_output=*/9, output_zp, scratch.data(),
      output.data(), &context);
  const std::vector<int8_t> expected_output = {
      5,   -9, -2, -30, -5, -11, -22, -18, 18,
      -19, 2,  11, -5,  9,  -2,  10,  -38, -22,
  };

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized matmul with 2 * 30 input and 9 * 30 matrix with zero point.
TEST(uKernels, QuantMatrixBatchVectorMultiply8x8_8WithZPTest) {
  const int32_t input_zp = 3;
  const std::vector<int8_t> input = {
      4,   -41, 5,   -41, 22,  17, -30, 24,  13,  -47, 18, 9,   -11, -30, 16,
      -47, 12,  36,  -20, 27,  -3, 0,   -51, -31, 3,   -8, -38, 43,  23,  12,
      11,  -23, -26, 23,  14,  -9, -44, 22,  21,  -30, 3,  -47, -26, -21, -24,
      -44, 34,  -11, -23, -28, 26, -38, 19,  35,  9,   23, 6,   -42, -25, 28,
  };
  const std::vector<int8_t> input_to_gate_weights = {
      13,  -7,  -20, -22, 8,   -46, 9,   -2,  -18, -42, 40,  28,  -7,  24,  34,
      -7,  -24, -24, 19,  14,  -19, -6,  -2,  -3,  5,   -36, -13, 6,   -27, 36,
      -23, 0,   20,  -37, -23, 9,   17,  -41, 33,  -15, -18, -42, -41, -34, -16,
      -6,  12,  -14, -15, -20, -14, 21,  -3,  -1,  -26, 54,  51,  35,  -14, 9,
      -2,  13,  -6,  39,  34,  -21, 39,  -51, 19,  -44, 52,  0,   -2,  -38, -35,
      -33, 4,   -22, -37, 27,  -23, 3,   -10, 5,   32,  6,   1,   -35, 24,  -19,
      46,  43,  -55, 5,   38,  -14, 32,  -43, -44, -17, -13, -28, 56,  28,  -42,
      4,   10,  -7,  25,  -15, -9,  -25, -14, -15, 6,   -10, -22, 40,  -72, 18,
      -6,  -18, -2,  37,  -13, -10, 11,  -9,  32,  -28, 19,  -2,  4,   -31, 50,
      -15, 23,  -34, -9,  41,  -6,  -34, 17,  2,   24,  -15, 21,  -17, -8,  -20,
      1,   -63, 19,  -40, 12,  -5,  5,   -6,  1,   19,  -9,  -23, 5,   -34, 11,
      26,  21,  54,  34,  -43, -29, 1,   16,  31,  -56, -28, 57,  -15, -23, 37,
      -17, -3,  -6,  29,  18,  77,  17,  -20, -14, -19, 8,   -24, -7,  -45, -3,
      0,   -25, -8,  6,   9,   3,   -15, 51,  4,   -15, -19, -16, -14, -47, -52,
      25,  9,   58,  26,  -9,  -27, 49,  -6,  -21, 21,  18,  12,  -9,  -9,  14,
      31,  -26, -19, -50, 17,  35,  11,  -10, 22,  -16, -43, -2,  26,  55,  -20,
      -7,  21,  33,  -20, 26,  -15, -22, 30,  27,  3,   -34, 26,  12,  -1,  19,
      26,  -25, 10,  30,  30,  -14, -23, -23, -35, -16, 26,  -41, 11,  1,   21,
  };
  const int32_t multiplier = 1347771520;
  const int32_t shift = -7;
  const int32_t output_zp = -11;

  std::vector<int8_t> output = {1, 2, 3, 4, 5,  6,  5,  4,  3,
                                2, 1, 2, 8, -1, -2, 11, 17, 18};

  MatrixBatchVectorMultiply(
      input.data(), input_zp, input_to_gate_weights.data(), multiplier, shift,
      /*n_batch=*/2, /*n_input=*/30, /*n_cell=*/9, output.data(), output_zp);
  const std::vector<int8_t> expected_output = {6,   -9,  -4, -32, -10, -17,
                                               -25, -25, 14, -19, 3,   10,
                                               -12, 10,  0,  1,   -57, -41};

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized matmul with 2 * 30 input and 9 * 30 matrix with zero point.
TEST(uKernels, QuantMatrixBatchVectorMultiply16x8_8WithZPTest) {
  const std::vector<int16_t> input = {
      400, -41, 5,   -41, 22,  17, -30, 24,  130, -47, 18, 9,   -11, -30, 16,
      -47, 12,  36,  -20, 27,  -3, 0,   -51, -31, 3,   -8, -38, 43,  23,  12,
      11,  -23, -26, 23,  14,  -9, -44, 22,  21,  -30, 3,  -47, -26, -21, -24,
      -44, 34,  -11, -23, -28, 26, -38, 19,  35,  9,   23, 6,   -42, -25, 28,
  };
  const std::vector<int8_t> input_to_gate_weights = {
      13,  -7,  -20, -22, 8,   -46, 9,   -2,  -18, -42, 40,  28,  -7,  24,  34,
      -7,  -24, -24, 19,  14,  -19, -6,  -2,  -3,  5,   -36, -13, 6,   -27, 36,
      -23, 0,   20,  -37, -23, 9,   17,  -41, 33,  -15, -18, -42, -41, -34, -16,
      -6,  12,  -14, -15, -20, -14, 21,  -3,  -1,  -26, 54,  51,  35,  -14, 9,
      -2,  13,  -6,  39,  34,  -21, 39,  -51, 19,  -44, 52,  0,   -2,  -38, -35,
      -33, 4,   -22, -37, 27,  -23, 3,   -10, 5,   32,  6,   1,   -35, 24,  -19,
      46,  43,  -55, 5,   38,  -14, 32,  -43, -44, -17, -13, -28, 56,  28,  -42,
      4,   10,  -7,  25,  -15, -9,  -25, -14, -15, 6,   -10, -22, 40,  -72, 18,
      -6,  -18, -2,  37,  -13, -10, 11,  -9,  32,  -28, 19,  -2,  4,   -31, 50,
      -15, 23,  -34, -9,  41,  -6,  -34, 17,  2,   24,  -15, 21,  -17, -8,  -20,
      1,   -63, 19,  -40, 12,  -5,  5,   -6,  1,   19,  -9,  -23, 5,   -34, 11,
      26,  21,  54,  34,  -43, -29, 1,   16,  31,  -56, -28, 57,  -15, -23, 37,
      -17, -3,  -6,  29,  18,  77,  17,  -20, -14, -19, 8,   -24, -7,  -45, -3,
      0,   -25, -8,  6,   9,   3,   -15, 51,  4,   -15, -19, -16, -14, -47, -52,
      25,  9,   58,  26,  -9,  -27, 49,  -6,  -21, 21,  18,  12,  -9,  -9,  14,
      31,  -26, -19, -50, 17,  35,  11,  -10, 22,  -16, -43, -2,  26,  55,  -20,
      -7,  21,  33,  -20, 26,  -15, -22, 30,  27,  3,   -34, 26,  12,  -1,  19,
      26,  -25, 10,  30,  30,  -14, -23, -23, -35, -16, 26,  -41, 11,  1,   21,
  };

  const std::vector<int32_t> input_zeropoint_times_weights = {
      0, 2, 3, 4, 5, 4, 3, 2, 10,
  };
  const int32_t multiplier = 1347771520;
  const int32_t shift = -8;
  const int32_t output_zp = -11;

  std::vector<int8_t> output = {1, 2, 3, 4, 5,  6,  5,  4,  3,
                                2, 1, 2, 8, -1, -2, 11, 17, 18};

  MatrixBatchVectorMultiply(
      input.data(), input_to_gate_weights.data(), multiplier, shift,
      input_zeropoint_times_weights.data(),
      /*n_batch=*/2, /*n_hidden=*/30, /*n_output=*/9, output_zp, output.data());
  const std::vector<int8_t> expected_output = {4,   -24, -5, 10,  -7,  -13,
                                               -39, 2,   3,  -16, -5,  -1,
                                               -12, -1,  -6, -6,  -33, -25};

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized matmul with 9 * 30 matrix.
TEST(uKernels, MatrixScalarMultiplyAccumulateTest) {
  std::vector<int32_t> output = {
      -620, -170, -395, 715, -1220, -1080, 1130, -260, -470,
  };
  const std::vector<int8_t> weight = {
      -10, -4,  -8,  16,  4,   -16, -1,  11,  1,   2,   -25, 19,  7,   9,   2,
      -24, -2,  10,  -7,  7,   -5,  -2,  3,   4,   3,   -4,  -7,  -11, -13, -18,
      11,  10,  12,  -9,  17,  -15, -5,  20,  -6,  -11, 2,   -6,  -18, 15,  4,
      4,   -9,  -2,  -3,  -9,  -13, 17,  -21, 5,   3,   -12, 0,   -4,  9,   -5,
      10,  -2,  8,   1,   -10, -6,  1,   -9,  10,  11,  -1,  -5,  4,   -7,  -4,
      -4,  4,   12,  -7,  -5,  -9,  -19, 6,   -4,  12,  -17, -22, 0,   9,   -4,
      -5,  5,   -8,  8,   3,   15,  -18, -18, 5,   3,   -12, 5,   -10, 7,   7,
      -9,  17,  2,   -11, -25, 3,   19,  -6,  7,   1,   7,   5,   -3,  11,  3,
      0,   -8,  8,   -2,  -2,  -12, 14,  -5,  7,   8,   16,  20,  -16, -5,  -5,
      1,   -10, -6,  14,  10,  -12, 10,  -6,  5,   0,   3,   8,   -9,  -13, -2,
      4,   4,   -16, -17, -9,  16,  -5,  14,  -9,  -5,  -12, 0,   17,  6,   -1,
      16,  -20, 1,   -11, -1,  -10, -21, 13,  4,   -12, -7,  0,   -14, -6,  3,
      -4,  6,   -18, -3,  -1,  14,  -8,  -6,  -15, 5,   12,  -3,  -10, 4,   6,
      -5,  -20, 0,   3,   -3,  -7,  1,   2,   -10, 7,   -3,  6,   1,   -12, 6,
      4,   -12, 2,   6,   -20, 0,   5,   23,  15,  14,  9,   8,   20,  -2,  9,
      -8,  -8,  -7,  -4,  -8,  -9,  7,   -12, -2,  2,   1,   -14, 31,  4,   -14,
      3,   10,  -18, -17, -1,  18,  1,   12,  0,   7,   -3,  -5,  8,   -9,  18,
      17,  7,   -15, 3,   20,  4,   -8,  16,  6,   -3,  -3,  9,   -4,  -6,  4,
  };
  MatrixScalarMultiplyAccumulate(weight.data(), 3, 9, 30, output.data());
  const std::vector<int32_t> expected_output = {
      -797, -227, -536, 739, -1187, -1314, 965, -140, -257,
  };

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
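
// MatrixScalarMultiplyAccumulate adds scalar * (sum of each 30-wide row) to
// the corresponding output entry; e.g. the first row of weights above sums
// to -59, and -620 + 3 * (-59) = -797, the first expected value.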

// Quantized layer norm of n_batch = 2 and n_input = 15.
TEST(uKernels, QuantApplyLayerNormTest) {
  const std::vector<int16_t> input = {
      -310,  596,   34,   -68,  475,  92,  672, -54,  -913, -200,
      -1194, -836,  -620, -237, 991,  533, 721, -736, -8,   -941,
      -372,  -1084, 591,  2557, -779, 175, 582, 956,  -287, 944,
  };
  const std::vector<int16_t> layer_norm_weights = {
      21849, 22882, 20626, 23854, 24779, 26354, 12980, 26231,
      23716, 27271, 24937, 22647, 24715, 22854, 19646,
  };
  const std::vector<int32_t> bias_weight = {
      -14175520, -13805465, -16027609, -13786809, -13321033,
      -14399810, -15055368, -14536623, -14508746, -13784007,
      -15206609, -15125830, -14996304, -14847597, -12814379,
  };
  const int32_t multiplier = 1895840000;
  const int32_t shift = -13;
  const int32_t limit = 1;

  std::vector<int16_t> output(2 * 15, 0);
  ApplyLayerNorm(input.data(), layer_norm_weights.data(), bias_weight.data(),
                 multiplier, shift, limit, 2, 15, output.data());
  const std::vector<int16_t> expected_output = {
      -9407,  5846,   -4802,  -5295,  4822,   -2390,  930,   -5283,
      -20352, -7846,  -26539, -18704, -15829, -8627,  10313, -2522,
      -132,   -16058, -8206,  -19158, -13296, -14407, -1235, 20612,
      -18591, -6738,  -2274,  2602,   -11622, 1565,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized layer norm of n_batch = 2 and n_input = 15.
TEST(uKernels, QuantApplyLayerNormFloatTest) {
  const std::vector<int16_t> input = {
      -310,  596,   34,   -68,  475,  92,  672, -54,  -913, -200,
      -1194, -836,  -620, -237, 991,  533, 721, -736, -8,   -941,
      -372,  -1084, 591,  2557, -779, 175, 582, 956,  -287, 944,
  };
  const std::vector<int16_t> layer_norm_weights = {
      21849, 22882, 20626, 23854, 24779, 26354, 12980, 26231,
      23716, 27271, 24937, 22647, 24715, 22854, 19646,
  };
  const std::vector<int32_t> bias_weight = {
      -14175520, -13805465, -16027609, -13786809, -13321033,
      -14399810, -15055368, -14536623, -14508746, -13784007,
      -15206609, -15125830, -14996304, -14847597, -12814379,
  };
  const int32_t multiplier = 1895840000;
  const int32_t shift = -13;

  std::vector<int16_t> output(2 * 15, 0);
  ApplyLayerNormFloat(input.data(), layer_norm_weights.data(), multiplier,
                      shift, bias_weight.data(), 2, 15, output.data());
  const std::vector<int16_t> expected_output = {
      -9408,  5844,   -4803,  -5297,  4826,   -2392,  927,   -5286,
      -20353, -7851,  -26534, -18701, -15830, -8623,  10312, -2524,
      -136,   -16053, -8206,  -19160, -13299, -14407, -1233, 20617,
      -18594, -6736,  -2272,  2597,   -11620, 1566};

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized tanh with Q0.15 input and Q0.15 output.
TEST(uKernels, QuantTanh0Test) {
  const std::vector<int16_t> input = {
      -145, 899, -176, -35,  264, 289,  8,    27,   -37,  -1310,
      -120, 127, -16,  106,  370, -583, -299, 93,   -548, 548,
      653,  -29, -53,  1058, -52, -164, -149, -635, 201,  -1297,
      -145, 899, -176, -35,  264, 289,  8,    27,   -37,  -1310,
      -120, 127, -16,  106,  370, -583, -299, 93,   -548, 548,
      653,  -29, -53,  1058, -52, -164, -149, -635, 201,  -1297,
  };
  std::vector<int16_t> output(4 * 15, 0);
  ApplyTanh(0, input.data(), 4, 15, output.data());
  const std::vector<int16_t> expected_output = {
      -136, 904, -176, -40,  260, 292,  8,    28,   -44,  -1304,
      -120, 120, -24,  112,  376, -576, -308, 88,   -544, 544,
      652,  -32, -60,  1056, -56, -156, -144, -636, 192,  -1300,
      -136, 904, -176, -40,  260, 292,  8,    28,   -44,  -1304,
      -120, 120, -24,  112,  376, -576, -308, 88,   -544, 544,
      652,  -32, -60,  1056, -56, -156, -144, -636, 192,  -1300,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized tanh with Q3.12 input and Q0.15 output.
TEST(uKernels, QuantTanh3Test) {
  const std::vector<int16_t> input = {
      -145, 899, -176, -35,  264, 289,  8,    27,   -37,  -1310,
      -120, 127, -16,  106,  370, -583, -299, 93,   -548, 548,
      653,  -29, -53,  1058, -52, -164, -149, -635, 201,  -1297,
      -145, 899, -176, -35,  264, 289,  8,    27,   -37,  -1310,
      -120, 127, -16,  106,  370, -583, -299, 93,   -548, 548,
      653,  -29, -53,  1058, -52, -164, -149, -635, 201,  -1297,
  };
  std::vector<int16_t> output(4 * 15, 0);
  ApplyTanh(3, input.data(), 4, 15, output.data());
  const std::vector<int16_t> expected_output = {
      -1156, 7076, -1412, -276, 2104, 2308,  64,    220,   -288,  -10132,
      -964,  1016, -120,  844,  2944, -4640, -2392, 736,   -4352, 4352,
      5180,  -232, -428,  8276, -412, -1308, -1196, -5044, 1612,  -10044,
      -1156, 7076, -1412, -276, 2104, 2308,  64,    220,   -288,  -10132,
      -964,  1016, -120,  844,  2944, -4640, -2392, 736,   -4352, 4352,
      5180,  -232, -428,  8276, -412, -1308, -1196, -5044, 1612,  -10044,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
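
// In the ApplyTanh tests, the first argument is the number of integer bits
// in the input format: 0 means the int16 input is read as Q0.15
// (value = raw / 2^15), while 3 means Q3.12 (value = raw / 2^12). The output
// is always Q0.15.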

// Quantized tanh with float calculation.
TEST(uKernels, QuantTanhFloatTest) {
  const std::vector<int16_t> input = {
      -1,   0,   1,    -35,  264, 289,  8,    27,   -37,  -1310,
      -120, 127, -16,  106,  370, -583, -299, 93,   -548, 548,
      653,  -29, -53,  1058, -52, -164, -149, -635, 201,  -1297,
      -145, 899, -176, -35,  264, 289,  8,    27,   -37,  -1310,
      -120, 127, -16,  106,  370, -583, -299, 93,   -548, 548,
      653,  -29, -53,  1058, -52, -164, -149, -635, 201,  -1297,
  };
  std::vector<int16_t> output(4 * 15, 0);
  ApplyTanhFloat(input.data(), 4, 15, -12, output.data());
  const std::vector<int16_t> expected_output = {
      -8,    0,    8,     -279, 2109, 2308,  63,    215,   -295,  -10136,
      -959,  1015, -127,  847,  2951, -4632, -2387, 743,   -4358, 4358,
      5180,  -231, -423,  8280, -415, -1311, -1191, -5039, 1606,  -10042,
      -1159, 7078, -1407, -279, 2109, 2308,  63,    215,   -295,  -10136,
      -959,  1015, -127,  847,  2951, -4632, -2387, 743,   -4358, 4358,
      5180,  -231, -423,  8280, -415, -1311, -1191, -5039, 1606,  -10042};

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized tanh with Q4.11 input and Q0.15 output.
TEST(uKernels, QuantTanh4Test) {
  const std::vector<int16_t> input = {
      -5,  163, -31, -5,  54, 90, 1,  2,  -4, -42, -8,  29,  0,   47, 150,
      -26, -36, 9,   -73, 25, 14, -2, -1, 29, -10, -12, -18, -29, 51, -92,
      -5,  163, -31, -5,  54, 90, 1,  2,  -4, -42, -8,  29,  0,   47, 150,
      -26, -36, 9,   -73, 25, 14, -2, -1, 29, -10, -12, -18, -29, 51, -92,
  };
  std::vector<int16_t> output(4 * 15, 0);
  ApplyTanh(4, input.data(), 4, 15, output.data());
  const std::vector<int16_t> expected_output = {
      -76,  2596, -496, -76, 856,  1436, 24,   36,   -64,   -672,
      -120, 456,  0,    752, 2400, -412, -576, 148,  -1168, 400,
      216,  -36,  -24,  456, -164, -192, -292, -456, 820,   -1476,
      -76,  2596, -496, -76, 856,  1436, 24,   36,   -64,   -672,
      -120, 456,  0,    752, 2400, -412, -576, 148,  -1168, 400,
      216,  -36,  -24,  456, -164, -192, -292, -456, 820,   -1476,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized sigmoid with Q3.12 input and Q0.15 output.
TEST(uKernels, QuantSigmoidTest) {
  const std::vector<int16_t> input = {
      -10500, 1398,   -6963,  -7404,  485,    -5401,  -1757,  -7668,  -19248,
      -9692,  -24249, -17923, -15840, -10026, 5249,   -89,    1787,   -16178,
      -6691,  -19524, -13439, -24048, -1123,  32767,  -17267, -3378,  823,
      11482,  -11139, 7508,   -10500, 1398,   -6963,  -7404,  485,    -5401,
      -1757,  -7668,  -19248, -9692,  -24249, -17923, -15840, -10026, 5249,
      -89,    1787,   -16178, -6691,  -19524, -13439, -24048, -1123,  32767,
      -17267, -3378,  823,    11482,  -11139, 7508,
  };
  std::vector<int16_t> output(4 * 15, 0);
  ApplySigmoid(input.data(), 4, 15, output.data());
  const std::vector<int16_t> expected_output = {
      2339, 19152, 5063,  4617,  17350, 6917,  12921, 4371,  299,  2813,
      89,   409,   673,   2605,  25646, 16207, 19904, 615,   5353, 273,
      1187, 91,    14153, 32756, 475,   9983,  18026, 30898, 2023, 28246,
      2339, 19152, 5063,  4617,  17350, 6917,  12921, 4371,  299,  2813,
      89,   409,   673,   2605,  25646, 16207, 19904, 615,   5353, 273,
      1187, 91,    14153, 32756, 475,   9983,  18026, 30898, 2023, 28246,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized sigmoid with Q3.12 input and Q0.15 output.
TEST(uKernels, QuantSigmoidFloatTest) {
  const std::vector<int16_t> input = {
      -10500, 1398,   -6963,  -7404,  485,    -5401,  -1757,  -7668,  -19248,
      -9692,  -24249, -17923, -15840, -10026, 5249,   -89,    1787,   -16178,
      -6691,  -19524, -13439, -24048, -1123,  32767,  -17267, -3378,  823,
      11482,  -11139, 7508,   -10500, 1398,   -6963,  -7404,  485,    -5401,
      -1757,  -7668,  -19248, -9692,  -24249, -17923, -15840, -10026, 5249,
      -89,    1787,   -16178, -6691,  -19524, -13439, -24048, -1123,  32767,
      -17267, -3378,  823,    11482,  -11139, 7508,
  };
  std::vector<int16_t> output(4 * 15, 0);
  ApplySigmoidFloat(input.data(), 4, 15, output.data());
  const std::vector<int16_t> expected_output = {
      2343, 19153, 5061,  4617,  17352, 6915,  12922, 4368,  295,  2811,
      87,   407,   671,   2608,  25647, 16206, 19902, 619,   5352, 276,
      1187, 92,    14151, 32757, 476,   9986,  18024, 30895, 2026, 28249,
      2343, 19153, 5061,  4617,  17352, 6915,  12922, 4368,  295,  2811,
      87,   407,   671,   2608,  25647, 16206, 19902, 619,   5352, 276,
      1187, 92,    14151, 32757, 476,   9986,  18024, 30895, 2026, 28249};

  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized multiply with 16-bit output and a 15-bit shift.
TEST(uKernels, QuantMul16bitOut15ShiftTest) {
  const std::vector<int16_t> input1 = {
      2491, 32767, -32768, 32767, -32768, 32767, 32767, -32768, -32768, 2157,
      4545, 14835, 1285,   29498, 26788,  2907,  7877,  6331,   8775,   3001,
      1399, 4683,  1437,   1853,  12163,  4927,  7977,  3001,   16612,  4791,
  };
  const std::vector<int16_t> input2 = {
      -1156, 32767, -32768, -32768, 32767, 2308,  64,    220,   -288,  -10132,
      -964,  1016,  -120,   844,    2944,  -4640, -2392, 736,   -4352, 4352,
      5180,  -232,  -428,   8276,   -412,  -1308, -1196, -5044, 1612,  -10044,
  };
  std::vector<int16_t> output(2 * 15, 0);
  CwiseMul(input1.data(), input2.data(), 2, 15, 15, output.data());
  const std::vector<int16_t> expected_output = {
      -88,  32766, -32768, -32767, -32767, 2308, 64,   -220, 288,   -667,
      -134, 460,   -5,     760,    2407,   -412, -575, 142,  -1165, 399,
      221,  -33,   -19,    468,    -153,   -197, -291, -462, 817,   -1469,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
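
// CwiseMul with a shift computes round(a * b / 2^shift), saturated to int16;
// e.g. 2491 * -1156 = -2879596, and -2879596 / 2^15 ~= -87.9 -> -88, the
// first expected value above.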

// Quantized multiply with 16-bit output and a 19-bit shift.
TEST(uKernels, QuantMul16bitOut19ShiftTest) {
  const std::vector<int16_t> input1 = {
      2491, 32767, -32768, 32767, -32768, 32767, 32767, -32768, -32768, 2157,
      4545, 14835, 1285,   29498, 26788,  2907,  7877,  6331,   8775,   3001,
      1399, 4683,  1437,   1853,  12163,  4927,  7977,  3001,   16612,  4791,
  };
  const std::vector<int16_t> input2 = {
      -1156, 32767, -32768, -32768, 32767, 2308,  64,    220,   -288,  -10132,
      -964,  1016,  -120,   844,    2944,  -4640, -2392, 736,   -4352, 4352,
      5180,  -232,  -428,   8276,   -412,  -1308, -1196, -5044, 1612,  -10044,
  };
  std::vector<int16_t> output(2 * 15, 0);
  CwiseMul(input1.data(), input2.data(), 2, 15, 19, output.data());
  const std::vector<int16_t> expected_output = {
      -5, 2048, 2048, -2048, -2048, 144, 4,   -14, 18,  -42,
      -8, 29,   0,    47,    150,   -26, -36, 9,   -73, 25,
      14, -2,   -1,   29,    -10,   -12, -18, -29, 51,  -92,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized multiply with arbitrary scale.
TEST(uKernels, QuantMul8bitArbitraryScaleTest) {
  // scale = 0.000028.
  int multiplier = 1970324837;
  int shift = -15;

  const std::vector<int16_t> input1 = {
      2491, 32767, -32768, 32767, -32768, 32767, 32767, -32768, -32768, 2157,
      4545, 14835, 1285,   29498, 26788,  2907,  7877,  6331,   8775,   3001,
      1399, 4683,  1437,   1853,  12163,  4927,  7977,  3001,   16612,  4791,
  };
  const std::vector<int16_t> input2 = {
      -1156, 32767, -32768, -32768, 32767, 2308,  64,    220,   -288,  -10132,
      -964,  1016,  -120,   844,    2944,  -4640, -2392, 736,   -4352, 4352,
      5180,  -232,  -428,   8276,   -412,  -1308, -1196, -5044, 1612,  -10044,
  };
  std::vector<int8_t> output(2 * 15, 0);
  CwiseMul(input1.data(), input2.data(), multiplier, shift, 2, 15, 3,
           output.data());
  const std::vector<int8_t> expected_output = {
      -84,  127, 127, -128, -128, 127,  56,   -128, 127,  -128,
      -126, 127, -7,  127,  127,  -128, -128, 127,  -128, 127,
      127,  -33, -20, 127,  -128, -128, -128, -128, 127,  -128,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

// Quantized element-wise add with saturation.
TEST(uKernels, QuantAddTest) {
  const std::vector<int16_t> input1 = {
      2491,   32767, -32768, 32767, -32768, 32767, 32767, -32768, -32768, 20000,
      -20000, 14835, 1285,   29498, 26788,  2907,  7877,  6331,   8775,   3001,
      1399,   4683,  1437,   1853,  12163,  4927,  7977,  3001,   16612,  4791,
  };
  const std::vector<int16_t> input2 = {
      -1156,  32767, -32768, -32768, 32767, 2308,  64,    220,   -288,  20000,
      -20000, 1016,  -120,   844,    2944,  -4640, -2392, 736,   -4352, 4352,
      5180,   -232,  -428,   8276,   -412,  -1308, -1196, -5044, 1612,  -10044,
  };
  std::vector<int16_t> output(2 * 15, 0);
  CwiseAdd(input1.data(), input2.data(), 2, 15, output.data());
  const std::vector<int16_t> expected_output = {
      1335,   32767, -32768, -1,    -1,    32767, 32767, -32548, -32768, 32767,
      -32768, 15851, 1165,   30342, 29732, -1733, 5485,  7067,   4423,   7353,
      6579,   4451,  1009,   10129, 11751, 3619,  6781,  -2043,  18224,  -5253,
  };
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

TEST(uKernels, ClipTest) {
  constexpr int kVectorSize = 10;
  constexpr float kAbsLimit = 2.0;
  std::vector<float> input = {0.0,  -0.5, 1.0,  -1.5, 2.0,
                              -2.5, 3.0,  -3.5, 4.0,  -4.5};
  CwiseClipping(input.data(), kVectorSize, kAbsLimit);
  const std::vector<float> expected_output = {0.0,  -0.5, 1.0,  -1.5, 2.0,
                                              -2.0, 2.0,  -2.0, 2.0,  -2.0};
  EXPECT_THAT(input, testing::ElementsAreArray(expected_output));
}

// Quantized clipping for 16 bit.
TEST(uKernels, QuantClip16Test) {
  constexpr int kVectorSize = 30;
  constexpr int16_t kAbsLimit = 300;
  std::vector<int16_t> input = {
      -10500, 1,     -2,     -7404,  200,    -5401,  -1757, -7668,
      -19248, -9692, -24249, -17923, -15840, -10026, 5249,  -89,
      1787,   -200,  -6691,  -19524, -13439, -24048, -1123, 32767,
      -17267, -3378, 823,    11482,  -11139, 7508,
  };
  CwiseClipping(input.data(), kVectorSize, kAbsLimit);
  const std::vector<int16_t> expected_output = {
      -300, 1,    -2,   -300, 200,  -300, -300, -300, -300, -300,
      -300, -300, -300, -300, 300,  -89,  300,  -200, -300, -300,
      -300, -300, -300, 300,  -300, -300, 300,  300,  -300, 300,
  };
  EXPECT_THAT(input, testing::ElementsAreArray(expected_output));
}

// Quantized clipping for 8 bit.
TEST(uKernels, QuantClip8Test) {
  constexpr int kVectorSize = 30;
  constexpr int8_t kAbsLimit = 32;
  std::vector<int8_t> input = {
      4,   -11, -5, -34, -10, -17, -27, -22, 15,  127, -128, 1,  3, 56, 3,
      -21, 1,   9,  -13, 10,  0,   -1,  -55, -40, 127, -128, 11, 4, 6,  32,
  };
  CwiseClipping(input.data(), kVectorSize, kAbsLimit);
  const std::vector<int8_t> expected_output = {
      4,   -11, -5, -32, -10, -17, -27, -22, 15,  32, -32, 1,  3, 32, 3,
      -21, 1,   9,  -13, 10,  0,   -1,  -32, -32, 32, -32, 11, 4, 6,  32,
  };
  EXPECT_THAT(input, testing::ElementsAreArray(expected_output));
}

struct MatrixVectorData {
  // Contains dense parameters.
  std::vector<int8_t> matrix;

  // Like matrix, but with about half of the parameters set to zero.
  // Use this to create golden output for sparse matrix tests.
  std::vector<int8_t> zeroed_matrix;

  // zeroed_matrix described in sparse form.
  std::vector<int8_t> sparse_matrix;
  std::vector<uint8_t> ledger;

  std::vector<int8_t> vectors;
  std::vector<float> scale_factors;
  std::vector<float> results;

  // Per channel scale data.
  std::vector<float> per_channel_scales;
  std::vector<int32_t> input_offsets;

  int rows;
  int cols;
  int batch;
};

MatrixVectorData SetupMatrixVectorData(int rows, int cols, int batch,
                                       bool negative = false,
                                       bool is_per_channel = false,
                                       bool init_to_one = false) {
  MatrixVectorData data;
  data.rows = rows;
  data.cols = cols;
  data.batch = batch;

  for (int i = 0; i < rows * cols; i++) {
    int sign = 1;
    if ((i % 3) == 0 && negative) sign = -1;
    data.matrix.push_back(sign * (i % 70));
  }
  for (int i = 0; i < cols * batch; i++) {
    int sign = 1;
    if ((i % 5) == 0 && negative) sign = -1;
    data.vectors.push_back(sign * (i % 50));
  }
  data.scale_factors = {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8};
  data.results.resize(rows * batch, init_to_one ? 1 : 0);

  data.zeroed_matrix = data.matrix;

  // Make a sparsification ledger.
  for (int i = 0; i < rows; i++) {
    int max_chunks = cols / 16;
    int selected_chunks = (max_chunks / 2);
    bool row_is_odd = (i % 2) > 0;
    bool max_chunks_is_odd = (max_chunks % 2) > 0;

    data.ledger.push_back(selected_chunks);
    if (max_chunks_is_odd && row_is_odd) {
      selected_chunks++;
    }

    // In odd rows, use odd chunk indexes.
    // In even rows, use even chunk indexes.
    for (int j = 0; j < max_chunks; j++) {
      const int chunk_start = i * cols + (j * 16);
      const int chunk_end = i * cols + (j * 16) + 16;
      if ((j % 2) == (i % 2)) {
        // Copy this chunk into the sparse matrix.
        data.ledger.push_back(j);
        for (int k = chunk_start; k < chunk_end; k++) {
          data.sparse_matrix.push_back(data.matrix[k]);
        }
      } else {
        // Zero this part out of zeroed_matrix.
        for (int k = chunk_start; k < chunk_end; k++) {
          data.zeroed_matrix[k] = 0;
        }
      }
    }
  }

  if (is_per_channel) {
    for (int i = 0; i < rows; i++) {
      if (i % 2 == 0) {
        data.per_channel_scales.push_back(0.5);
      } else {
        data.per_channel_scales.push_back(1.0);
      }
    }

    for (int i = 0; i < batch; i++) {
      for (int j = 0; j < cols; j++) {
        data.vectors[i * cols + j] += i;
      }
      data.input_offsets.push_back(i);
    }
  }
  return data;
}
1135 
TestDotprodMatrixBatchVectorMultiply(int rows,int cols,int batch,bool negative=false,bool init_to_one=false)1136 std::vector<float> TestDotprodMatrixBatchVectorMultiply(
1137     int rows, int cols, int batch, bool negative = false,
1138     bool init_to_one = false) {
1139   MatrixVectorData data =
1140       SetupMatrixVectorData(rows, cols, batch, negative, false, init_to_one);
1141 
1142   // All partial sums in this computation are small enough to fit in the
1143   // mantissa of a float, and the scale factors are all integers, so we expect
1144   // an exact result.
1145   MatrixBatchVectorMultiplyAccumulate(
1146       data.matrix.data(), rows, cols, data.vectors.data(),
1147       data.scale_factors.data(), batch, &data.results[0]);
1148   return data.results;
1149 }
1150 
TestSparseDotprodMatrixBatchVectorMultiply(int rows,int cols,int batch,bool negative=false)1151 std::vector<float> TestSparseDotprodMatrixBatchVectorMultiply(
1152     int rows, int cols, int batch, bool negative = false) {
1153   MatrixVectorData data = SetupMatrixVectorData(rows, cols, batch, negative);
1154   SparseMatrixBatchVectorMultiplyAccumulate(
1155       data.sparse_matrix.data(), data.ledger.data(), rows, cols,
1156       data.vectors.data(), data.scale_factors.data(), batch, &data.results[0]);
1157   return data.results;
1158 }
1159 
TestPerChannelDotprodMatrixBatchVectorMultiply(int rows,int cols,int batch,bool negative=false,bool is_per_channel=true)1160 std::vector<float> TestPerChannelDotprodMatrixBatchVectorMultiply(
1161     int rows, int cols, int batch, bool negative = false,
1162     bool is_per_channel = true) {
1163   MatrixVectorData data =
1164       SetupMatrixVectorData(rows, cols, batch, negative, is_per_channel);
1165   std::vector<int32_t> scratch(rows * batch);
1166   std::vector<int32_t> row_sums(rows);
1167   bool compute_row_sums = true;
1168   CpuBackendContext context;
1169   MatrixBatchVectorMultiplyAccumulate(
1170       data.matrix.data(), rows, cols, data.vectors.data(),
1171       data.scale_factors.data(), batch, &data.results[0],
1172       data.per_channel_scales.data(), data.input_offsets.data(), scratch.data(),
1173       row_sums.data(), &compute_row_sums, &context);
1174   return data.results;
1175 }
1176 
TEST(uKernels,DotprodMatrixBatchVectorMultiplyAccumulateTest)1177 TEST(uKernels, DotprodMatrixBatchVectorMultiplyAccumulateTest) {
1178   ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 16, 1),
1179               testing::ElementsAre(1240, 3160, 5080, 7000));
1180 
1181   ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 32, 2),
1182               testing::ElementsAre(10416, 26288, 8490, 23312, 18276, 70756,
1183                                    37416, 60916));
1184 
1185   ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 32, 3),
1186               testing::ElementsAre(10416, 26288, 8490, 23312, 18276, 70756,
1187                                    37416, 60916, 52080, 142704, 55878, 125712));
1188 
1189   ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(8, 1024, 3),
1190               testing::ElementsAreArray(
1191                   {841094,  853168,  866642,  840286,  860760,  862754,
1192                    843678,  872552,  1724476, 1769072, 1747588, 1738844,
1193                    1758240, 1742916, 1761612, 1755808, 2506896, 2564262,
1194                    2629188, 2515824, 2598390, 2569236, 2537352, 2645118}));
1195 
1196   const bool kNegative = true;
1197   ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 64, 1, kNegative),
1198               testing::ElementsAre(13696, 6904, 7764, 11806));
1199   ASSERT_THAT(
1200       TestDotprodMatrixBatchVectorMultiply(4, 32, 2, kNegative),
1201       testing::ElementsAre(3436, 3522, 1590, 6972, 2516, 20520, 456, 10628));
1202 
  // Initialize the results vector with 1s to verify that the code adds
  // to the results vector instead of zeroing it first.
  const bool kInitToOne = true;
  ASSERT_THAT(
      TestDotprodMatrixBatchVectorMultiply(4, 32, 2, kNegative, kInitToOne),
      testing::ElementsAre(3437, 3523, 1591, 6973, 2517, 20521, 457, 10629));
}

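// In the per-channel variants below, even rows get a per-channel scale of 0.5
// and odd rows 1.0 (see SetupMatrixVectorData), while the batch offsets added
// to the input vectors are cancelled out by the row-sum compensation; so each
// expected value is the corresponding dense result, halved on even rows.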
TEST(uKernels, PerChannelDotprodMatrixBatchVectorMultiplyAccumulateTest) {
  ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 16, 1),
              testing::ElementsAre(1240 / 2, 3160, 5080 / 2, 7000));

  ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 32, 2),
              testing::ElementsAre(10416 / 2, 26288, 8490 / 2, 23312, 18276 / 2,
                                   70756, 37416 / 2, 60916));

  ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 32, 3),
              testing::ElementsAre(10416 / 2, 26288, 8490 / 2, 23312, 18276 / 2,
                                   70756, 37416 / 2, 60916, 52080 / 2, 142704,
                                   55878 / 2, 125712));

  ASSERT_THAT(
      TestPerChannelDotprodMatrixBatchVectorMultiply(8, 1024, 3),
      testing::ElementsAreArray(
          {841094 / 2,  853168,  866642 / 2,  840286,  860760 / 2,  862754,
           843678 / 2,  872552,  1724476 / 2, 1769072, 1747588 / 2, 1738844,
           1758240 / 2, 1742916, 1761612 / 2, 1755808, 2506896 / 2, 2564262,
           2629188 / 2, 2515824, 2598390 / 2, 2569236, 2537352 / 2, 2645118}));
}

TEST(uKernels, DotprodMatrixBatchFourVectorMultiplyAccumulateDotprodTest) {
  ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 16, 4),
              testing::ElementsAreArray(
                  {1240, 3160, 6320, 18352, 15240, 45576, 4200, 16232}));
  ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 64, 4),
              testing::ElementsAreArray({45794, 38948, 88536, 84252, 157626,
                                         165312, 209864, 246128}));
  ASSERT_THAT(
      TestDotprodMatrixBatchVectorMultiply(2, 64, 8),
      testing::ElementsAreArray({45794, 38948, 88536, 84252, 157626, 165312,
                                 209864, 246128, 219700, 195550, 279684, 278928,
                                 413616, 445662, 374896, 365952}));

  ASSERT_THAT(
      TestDotprodMatrixBatchVectorMultiply(4, 64, 8),
      testing::ElementsAreArray(
          {45794,  38948,  34622,  32816,  88536,  84252,  85008,  90804,
           157626, 165312, 180558, 203364, 209864, 246128, 236472, 208896,
           219700, 195550, 184000, 185050, 279684, 278928, 293292, 322776,
           413616, 445662, 495348, 513674, 374896, 365952, 321168, 296544}));

  ASSERT_THAT(
      TestDotprodMatrixBatchVectorMultiply(16, 1024, 4),
      testing::ElementsAreArray(
          {841094,  853168,  866642,  840286,  860760,  862754,  843678,
           872552,  837586,  851270,  877414,  834188,  863062,  857846,
           841780,  879054,  1724476, 1769072, 1747588, 1738844, 1758240,
           1742916, 1761612, 1755808, 1737684, 1750780, 1747356, 1754152,
           1748348, 1753324, 1743320, 1754316, 2506896, 2564262, 2629188,
           2515824, 2598390, 2569236, 2537352, 2645118, 2508444, 2571480,
           2610576, 2510442, 2618208, 2566584, 2544570, 2614536, 3458904,
           3502688, 3474792, 3505976, 3499360, 3488264, 3485848, 3512832,
           3500616, 3482520, 3489624, 3469008, 3495992, 3524376, 3465680,
           3526264}));

  ASSERT_THAT(
      TestDotprodMatrixBatchVectorMultiply(4, 128, 4),
      testing::ElementsAreArray({87920, 80024, 92288, 103712, 228148, 224820,
                                 233812, 213124, 271284, 271788, 332772, 328236,
                                 419328, 431328, 411968, 417248}));

  ASSERT_THAT(
      TestDotprodMatrixBatchVectorMultiply(4, 128, 8),
      testing::ElementsAreArray(
          {87920,  80024,  92288,  103712, 228148, 224820, 233812, 213124,
           271284, 271788, 332772, 328236, 419328, 431328, 411968, 417248,
           482680, 523840, 560800, 593560, 563940, 609924, 566868, 644772,
           743708, 857780, 818972, 823284, 708384, 695008, 730912, 872096}));

  const bool kNegative = true;
  EXPECT_THAT(TestDotprodMatrixBatchVectorMultiply(1, 16, 1, kNegative),
              testing::ElementsAre(450));
  EXPECT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 64, 8, kNegative),
              testing::ElementsAreArray({13696, 6904, 9952, 12368, 22848, 61632,
                                         40424, 46776, 57630, 38670, 62976,
                                         49824, 39032, 71988, 60128, 148992}));

  std::vector<float> results =
      TestDotprodMatrixBatchVectorMultiply(256, 1024, 8);
  int64_t sum = 0;
  for (int i = 0; i < results.size(); i++) {
    sum += static_cast<int64_t>(results[i]);
  }
  EXPECT_EQ(7980076336, sum);
}

TEST(uKernels,
     PerChannelDotprodMatrixBatchFourVectorMultiplyAccumulateDotprodTest) {
  ASSERT_THAT(
      TestPerChannelDotprodMatrixBatchVectorMultiply(16, 1024, 4),
      testing::ElementsAreArray(
          {841094 / 2,  853168,  866642 / 2,  840286,  860760 / 2,  862754,
           843678 / 2,  872552,  837586 / 2,  851270,  877414 / 2,  834188,
           863062 / 2,  857846,  841780 / 2,  879054,  1724476 / 2, 1769072,
           1747588 / 2, 1738844, 1758240 / 2, 1742916, 1761612 / 2, 1755808,
           1737684 / 2, 1750780, 1747356 / 2, 1754152, 1748348 / 2, 1753324,
           1743320 / 2, 1754316, 2506896 / 2, 2564262, 2629188 / 2, 2515824,
           2598390 / 2, 2569236, 2537352 / 2, 2645118, 2508444 / 2, 2571480,
           2610576 / 2, 2510442, 2618208 / 2, 2566584, 2544570 / 2, 2614536,
           3458904 / 2, 3502688, 3474792 / 2, 3505976, 3499360 / 2, 3488264,
           3485848 / 2, 3512832, 3500616 / 2, 3482520, 3489624 / 2, 3469008,
           3495992 / 2, 3524376, 3465680 / 2, 3526264}));

  ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 128, 4),
              testing::ElementsAreArray(
                  {87920 / 2, 80024, 92288 / 2, 103712, 228148 / 2, 224820,
                   233812 / 2, 213124, 271284 / 2, 271788, 332772 / 2, 328236,
                   419328 / 2, 431328, 411968 / 2, 417248}));

  ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 128, 8),
              testing::ElementsAreArray(
                  {87920 / 2,  80024,  92288 / 2,  103712, 228148 / 2, 224820,
                   233812 / 2, 213124, 271284 / 2, 271788, 332772 / 2, 328236,
                   419328 / 2, 431328, 411968 / 2, 417248, 482680 / 2, 523840,
                   560800 / 2, 593560, 563940 / 2, 609924, 566868 / 2, 644772,
                   743708 / 2, 857780, 818972 / 2, 823284, 708384 / 2, 695008,
                   730912 / 2, 872096}));
}

TEST(uKernels, DotprodSparseMatrixBatchVectorMultiplyAccumulate) {
  EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 16, 1),
              testing::ElementsAre(0));
  EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 32, 1),
              testing::ElementsAre(1240));
  EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 1),
              testing::ElementsAre(26544));
  EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 2),
              testing::ElementsAre(26544, 24344));
  EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(4, 64, 4),
              testing::ElementsAreArray(
                  {26544, 15866, 22140, 11408, 24344, 53248, 42704, 39900,
                   48000, 94146, 101892, 81876, 87712, 105160, 148304, 75936}));

  const bool kNegative = true;
  EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 1, kNegative),
              testing::ElementsAre(8764));
  EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(2, 64, 2, kNegative),
              testing::ElementsAre(8764, 5196, 7204, 11148));
}

#ifdef __ANDROID__
TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
  // Note: we use 29 columns as this exercises all the NEON kernel paths: the
  // 16-block SIMD code, the 8-block postamble, and the leftover postamble.
  const int a_rows = 4, a_cols = 29;
  const int kWeightsPerUint32 = 4;
  /* clang-format off */
  const float a_float_data[] = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
      24.24, 25.25, 26.26, 27.27, 28.28, 0,
      /* 2nd row */
      -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
      -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
      /* 3rd row */
      1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
      13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
      23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
      -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};

  int8_t* a_int8_data = reinterpret_cast<int8_t*>(
      aligned_malloc(a_rows * a_cols, kWeightsPerUint32));
  float a_min, a_max;
  float scaling_factor_a;
  SymmetricQuantizeFloats(a_float_data, a_rows * a_cols, a_int8_data, &a_min,
                          &a_max, &scaling_factor_a);
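  // SymmetricQuantizeFloats maps each value to round(value * 127 / max_abs);
  // here max_abs = 28.28, so e.g. 1.1 -> round(1.1 * 127 / 28.28) = 5,
  // matching the first expected value below.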
  const int8_t expected_a_int8_data[] = {
    /* 1st row */
    5, 10, 15, 20, 25, 30, 35, 40, 44, 45, 50, 54, 59, 64, 68, 73, 77, 82, 86,
    91, 95, 100, 104, 109, 113, 118, 122, 127, 0,
    /* 2nd row */
    -5, -10, -15, -20, -25, -30, -35, -40, -44, -45, -50, -54, -59, -64, -68,
    -73, -77, -82, -86, -91, -95, -100, -104, -109, -113, -118, -122, -127, 0,
    /* 3rd row */
    5, -10, 15, -20, 25, -30, 35, -40, 44, -45, 50, -54, 59, -64, 68, -73, 77,
    -82, 86, -91, 95, -100, 104, -109, 113, -118, 122, -127, 0,
    /* 4th row */
    -5, 10, -15, 20, -25, 30, -35, 40, -44, 45, -50, 54, -59, 64, -68, 73, -77,
    82, -86, 91, -95, 100, -104, 109, -113, 118, -122, 127, 0,
  };
  for (int i = 0; i < a_rows * a_cols; ++i) {
    EXPECT_EQ(expected_a_int8_data[i], a_int8_data[i]);
  }

  const int b_rows = 29, b_cols = 1, batches = 2;
  const float b_float_data[] = {
    /* batch 1 */
    1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
    1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
    1.0,
    /* batch 2 */
    2.5, -2.1, 3.0, -1.3, 1.3, -1.1, 2.0, -1.7, 1.9, -1.5, 0.5, -0.7, 0.8, -0.3,
    2.8, -2.8, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1, 1.0, -2.5, 0.7, -1.9,
    0.2,
  };

  // Quantized values of B:
  int8_t b_int8_data[b_rows * b_cols * batches];
  float b_min, b_max;
  float scaling_factor_b[batches];
  SymmetricQuantizeFloats(b_float_data, b_rows * b_cols, b_int8_data, &b_min,
                          &b_max, &scaling_factor_b[0]);
  SymmetricQuantizeFloats(&b_float_data[b_rows * b_cols], b_rows * b_cols,
                          &b_int8_data[b_rows * b_cols], &b_min, &b_max,
                          &scaling_factor_b[1]);

  const int8_t expected_b_int8_data[] = {
    /* batch 1 */
    127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127,
    127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127,
    127,
    /* batch 2 */
    106, -89, 127, -55, 55, -47, 85, -72, 80, -64, 21, -30, 34, -13, 119, -119,
    47, -97, 80, -80, 89, -21, 102, -4, 42, -106, 30, -80, 8,
  };
  /* clang-format on */
  for (int i = 0; i < b_rows * b_cols * batches; ++i) {
    EXPECT_EQ(expected_b_int8_data[i], b_int8_data[i]);
  }

  // Full float operation results in:
  // -13.69, 13.69, 414.11, -414.11
  // -6.325, 6.325, 631.263, -631.263
  float c_float_data[a_rows * b_cols * batches];
  for (int i = 0; i < a_rows * b_cols * batches; ++i) {
    c_float_data[i] = 0.0;
  }

  // Testing product.
  const float scaling_factor_c[2] = {
      scaling_factor_a * scaling_factor_b[0],
      scaling_factor_a * scaling_factor_b[1],
  };
  MatrixBatchVectorMultiplyAccumulate(a_int8_data, a_rows, a_cols, b_int8_data,
                                      scaling_factor_c, batches, c_float_data);

  // Assert we obtain the expected recovered float values.
  const float expected_c_float_data[] = {
      -14.474, 14.474, 414.402, -414.402, -6.92228, 6.92228, 632.042, -632.042,
  };
  for (int i = 0; i < a_rows * b_cols * batches; ++i) {
    EXPECT_NEAR(expected_c_float_data[i], c_float_data[i], 0.001);
  }

  // Call version of MatrixBatchVectorMultiplyAccumulate that uses
  // CpuBackendGemm.
  std::vector<int32_t> accum_scratch(a_rows * batches);
  std::vector<float> c_float_data_2(a_rows * batches, 0.0);
  CpuBackendContext context;
  MatrixBatchVectorMultiplyAccumulate(
      a_int8_data, a_rows, a_cols, b_int8_data, scaling_factor_c, batches,
      accum_scratch.data(), c_float_data_2.data(), &context);

  // Assert (again) we obtain the expected recovered float values.
  for (int i = 0; i < a_rows * b_cols * batches; ++i) {
    EXPECT_NEAR(expected_c_float_data[i], c_float_data_2[i], 0.001);
  }

  aligned_free(a_int8_data);
}
#endif  // __ANDROID__

TEST(uKernels, SparseMatrixBatchVectorMultiplyAccumulateTest) {
  const int kRow = 4;
  const int kCol = 48;
  const int kBatch = 2;
  /* clang-format off */
  float matrix[kRow * kCol] = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
      39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0, 0, 0, 0,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
      -25.25, -26.26, -27.27, -28.28, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
      -26.26, 27.27, -28.28, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
      38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0, 0, 0, 0};

  // BCSR format of the above matrix.
  float matrix_values[] = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38, 39.39,
      40.40, 41.41, 42.42, 43.43, 44.44, 0, 0, 0, 0,
      /* 2nd row */
      -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24, -25.25,
      -26.26, -27.27, -28.28, 0, 0.0, 0.0, 0.0,
      /* 3rd row */
      17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25, -26.26,
      27.27, -28.28, 0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, -33.33, 34.34, -35.35, 36.36, -37.37, 38.38,
      -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0, 0, 0, 0};
  uint8_t ledger[] = {
      2, 0,  2,  // 1st row
      1, 1,      // 2nd row
      1, 1,      // 3rd row
      2, 0,  2   // 4th row
  };
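  // Each per-row ledger entry is a count of retained 16-wide blocks followed
  // by their block indexes; e.g. the 2nd row keeps one block covering columns
  // 16..31, hence the entry {1, 1}.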

  float vector[kBatch * kCol] = {
    /* 1st batch */
    1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
    1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
    1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
    1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
    /* 2nd batch */
    2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
    -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0,
    2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1, 1.0, -2.5,
    0.7, -1.9, 0.2, 0.0, 0.1, 0.2,
  };
  /* clang-format on */

  std::vector<float> dense_output(kRow * kBatch, 0.0);
  MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch,
                                      dense_output.data());

  EXPECT_THAT(dense_output, ElementsAreArray(ArrayFloatNear(
                                {-13.69, 6.06001, 272.7, -608.03, -9.66602,
                                 -10.201, 10.201, -713.897949},
                                1e-4)));

  std::vector<float> sparse_output(kRow * kBatch, 0.0);
  SparseMatrixBatchVectorMultiplyAccumulate(
      matrix_values, ledger, kRow, kCol, vector, kBatch, sparse_output.data());

  EXPECT_THAT(sparse_output,
              ElementsAreArray(ArrayFloatNear(dense_output, 1e-4)));
}

#ifdef __ANDROID__
TEST(uKernels,
     SparseMatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
  const int kRow = 4;
  const int kCol = 48;
  const int kBatch = 2;
  /* clang-format off */
  const int8_t quantized_matrix[] = {
      /* 1st row */
      3, 6, 9, 13, 16, 19, 22, 25, 28, 29, 32, 35, 38, 40, 43, 46, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 95, 98, 101, 104, 107, 110, 113, 115,
      118, 121, 124, 127, 0, 0, 0, 0,
      /* 2nd row */
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -49, -52, -55, -58, -61,
      -64, -66, -69, -72, -75, -78, -81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0,
      /* 3rd row */
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, -52, 55, -58, 61, -64,
      66, -69, 72, -75, 78, -81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0,
      /* 4th row */
      -3, 6, -9, 13, -16, 19, -22, 25, -28, 29, -32, 35, -38, 40, -43, 46, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -95, 98, -101, 104, -107, 110,
      -113, 115, -118, 121, -124, 127, 0, 0, 0, 0,
  };
  const int8_t quantized_matrix_values[] = {
      /* 1st row */
      3, 6, 9, 13, 16, 19, 22, 25, 28, 29, 32, 35, 38, 40, 43, 46, 95, 98, 101,
      104, 107, 110, 113, 115, 118, 121, 124, 127, 0, 0, 0, 0,
      /* 2nd row */
      -49, -52, -55, -58, -61, -64, -66, -69, -72, -75, -78, -81, 0, 0, 0, 0,
      /* 3rd row */
      49, -52, 55, -58, 61, -64, 66, -69, 72, -75, 78, -81, 0, 0, 0, 0,
      /* 4th row */
      -3, 6, -9, 13, -16, 19, -22, 25, -28, 29, -32, 35, -38, 40, -43, 46, -95,
      98, -101, 104, -107, 110, -113, 115, -118, 121, -124, 127, 0, 0, 0, 0,
  };
  uint8_t ledger[] = {
      2, 0,  2,  // 1st row
      1, 1,      // 2nd row
      1, 1,      // 3rd row
      2, 0,  2   // 4th row
  };

  float matrix_scaling_factor = 0.349921;
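  // 0.349921 ≈ 44.44 / 127, the symmetric-quantization scale implied by the
  // float matrix in the test above (max |value| = 44.44).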

  const int8_t quantized_vector[] = {
      /* 1st batch */
      127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127,
      -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127,
      127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127,
      -127, 127, -127, 127, -127, 127, -127, 127, -127,
      /* 2nd batch */
      106, 0, -89, 0, 127, 0, -55, 0, 55, 0, -47, 0, 85, 0, -72, 0, 80, 0,
      -64, 0, 21, 0, -30, 0, 34, 0, -13, 0, 119, 0, -119, 0, 47, -97, 80, -80,
      89, -21, 102, -4, 42, -106, 30, -80, 8, 1, 2, 3,
  };
  float vector_scaling_factor[2] = {0.00787402, 0.023622};
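  // These match the symmetric scales of the two batches: 1.0 / 127 ≈
  // 0.00787402 and 3.0 / 127 ≈ 0.023622 (the max |value| per batch in the
  // float test above is 1.0 and 3.0, respectively).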

  /* clang-format on */
  float result_scaling_factor[2] = {
      matrix_scaling_factor * vector_scaling_factor[0],
      matrix_scaling_factor * vector_scaling_factor[1],
  };
  std::vector<float> dense_output(kRow * kBatch, 0.0);
  MatrixBatchVectorMultiplyAccumulate(quantized_matrix, kRow, kCol,
                                      quantized_vector, result_scaling_factor,
                                      kBatch, dense_output.data());

  EXPECT_THAT(dense_output,
              ElementsAreArray(ArrayFloatNear(
                  {-13.646927, 6.298582, 272.938538, -607.813110, -6.637464,
                   -9.381721, 9.381721, -713.845642})));

  std::vector<float> sparse_output(kRow * kBatch, 0.0);
  SparseMatrixBatchVectorMultiplyAccumulate(
      quantized_matrix_values, ledger, kRow, kCol, quantized_vector,
      result_scaling_factor, kBatch, sparse_output.data());

  EXPECT_THAT(sparse_output,
              ElementsAreArray(ArrayFloatNear(
                  {-13.646927, 6.298582, 272.938538, -607.813110, -6.637464,
                   -9.381721, 9.381721, -713.845642})));
}
#endif  // __ANDROID__

TEST(uKernels, VectorVectorCwiseProductTest) {
  constexpr int kVectorSize = 10;
  static float input1[kVectorSize] = {0.0,  -0.5, 1.0,  -1.5, 2.0,
                                      -2.5, 3.0,  -3.5, 4.0,  -4.5};
  static float input2[kVectorSize] = {0.1,  -0.1, 0.1,  -0.1, 0.1,
                                      -0.1, 0.1,  -0.1, 0.1,  -0.1};
  std::vector<float> output(kVectorSize);
  VectorVectorCwiseProduct(input1, input2, kVectorSize, output.data());
  EXPECT_THAT(output,
              ElementsAreArray(ArrayFloatNear(
                  {0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45})));
}

TEST(uKernels, VectorVectorCwiseProductAccumulateTest) {
  constexpr int kVectorSize = 10;
  static float input1[kVectorSize] = {0.0,  -0.5, 1.0,  -1.5, 2.0,
                                      -2.5, 3.0,  -3.5, 4.0,  -4.5};
  static float input2[kVectorSize] = {0.1,  -0.1, 0.1,  -0.1, 0.1,
                                      -0.1, 0.1,  -0.1, 0.1,  -0.1};
  std::vector<float> output(kVectorSize);
  std::fill(output.begin(), output.end(), 1.0);
  VectorVectorCwiseProductAccumulate(input1, input2, kVectorSize,
                                     output.data());
  EXPECT_THAT(output,
              ElementsAreArray(ArrayFloatNear(
                  {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
}

TEST(uKernels, VectorBatchVectorAddTest) {
  constexpr int kVectorSize = 3;
  constexpr int kBatchSize = 2;
  static float input[kVectorSize] = {0.0, -0.5, 1.0};
  std::vector<float> output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
  VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data());
  EXPECT_THAT(output,
              testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0}));
}

TEST(uKernels, VectorBatchVectorAssignTest) {
  constexpr int kVectorSize = 5;
  constexpr int kBatchSize = 3;
  static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
  std::vector<float> output(kVectorSize * kBatchSize);
  VectorBatchVectorAssign(input, kVectorSize, kBatchSize, output.data());
  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
                          {0.0, -0.5, 1.0, -1.5, 2.0, 0.0, -0.5, 1.0, -1.5, 2.0,
                           0.0, -0.5, 1.0, -1.5, 2.0})));
}

TEST(uKernels, ApplySigmoidToVectorTest) {
  constexpr int kVectorSize = 5;
  static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
  std::vector<float> output(kVectorSize);
  ApplySigmoidToVector(input, kVectorSize, output.data());
  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
                          {0.5, 0.377541, 0.731059, 0.182426, 0.880797})));
}

TEST(uKernels, ApplyActivationToVectorTest) {
  constexpr int kVectorSize = 5;
  static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
  std::vector<float> output(kVectorSize);
  ApplyActivationToVector(input, kVectorSize, kTfLiteActRelu, output.data());
  EXPECT_THAT(output,
              ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 0.0, 2.0})));

  ApplyActivationToVector(input, kVectorSize, kTfLiteActTanh, output.data());
  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
                          {0.0, -0.462117, 0.761594, -0.905148, 0.964028})));
}

TEST(uKernels, Sub1VectorTest) {
  constexpr int kVectorSize = 5;
  static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
  std::vector<float> output(kVectorSize);
  Sub1Vector(input, kVectorSize, output.data());
  EXPECT_THAT(output,
              ElementsAreArray(ArrayFloatNear({1.0, 1.5, 0.0, 2.5, -1.0})));
}

TEST(uKernels, Sub1VectorInt16Test) {
  constexpr int kVectorSize = 30;
  static int16_t input[kVectorSize] = {
      32760, 300,   1,     2,    3, 4, 5, 6, 300, 1000,
      32767, 32000, 300,   1,    2, 3, 4, 5, 56,  300,
      1000,  32767, 32761, 1300, 1, 2, 3, 4, 5,   6,
  };
  std::vector<int16_t> output(kVectorSize);
  Sub1Vector(input, kVectorSize, output.data());
  EXPECT_THAT(
      output,
      testing::ElementsAreArray({
          7,     32467, 32766, 32765, 32764, 32763, 32762, 32761, 32467, 31767,
          0,     767,   32467, 32766, 32765, 32764, 32763, 32762, 32711, 32467,
          31767, 0,     6,     31467, 32766, 32765, 32764, 32763, 32762, 32761,
      }));
}

TEST(uKernels, VectorBatchVectorCwiseProductAccumulateInteger) {
  constexpr int kVectorSize = 29;
  constexpr int kBatchSize = 4;
  static int16_t vector[kVectorSize] = {-10, 9,  8,  7,  6,  5,  4,  3,  2, 1,
                                        0,   1,  2,  3,  4,  5,  6,  7,  8, 9,
                                        10,  11, 12, 13, 14, 15, 16, 17, 18};
  const std::vector<int16_t> batch_vector = {
      /* batch 0 */
      10, 11, 12, 13, 14, 15, 16, 17, 18, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1,
      2, 3, 4, 5, 6, 7, 8, 9,
      /* batch 1 */
      -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 0, 1,
      2, 3, 4, 5, 6, 7, 8, 9,
      /* batch 2 */
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12,
      13, 14, 15, 16, 17, 18,
      /* batch 3 */
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12,
      13, 14, 15, 16, 17, 18};
  std::vector<int16_t> batch_output = {
      /* batch 0 */
      -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 0, 1,
      2, 3, 4, 5, 6, 7, 8, 9,
      /* batch 1 */
      2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10, 9, 8, 7, 6, 5,
      4, 3, 2, 1, 10, 11, 12,
      /* batch 2 */
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12,
      13, 14, 15, 16, 17, 18,
      /* batch 3 */
      10, 11, 12, 13, 14, 15, 16, 17, 18, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1,
      13, 14, 15, 16, 17, 18};
  // Test with 0.25 scale, which is decomposed into (1073741824, -1).
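  // (1073741824 = 2^30; under TFLite's Q31 fixed-point convention the
  // effective scale is (multiplier / 2^31) * 2^shift = 0.5 * 0.5 = 0.25.)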
  VectorBatchVectorCwiseProductAccumulate(vector, kVectorSize,
                                          batch_vector.data(), kBatchSize,
                                          1073741824, -1, batch_output.data());

  const std::vector<int16_t> expected_output = {
      /* batch 0 */
      -35, 34, 32, 30, 27, 24, 20, 16, 11, -1, 10, 13, 16, 18, 19, 20, 21, 21,
      20, 0, 4, 8, 12, 17, 23, 29, 35, 42, 50,
      /* batch 1 */
      27, 24, 20, 18, 15, 14, 12, 12, 1, 2, 2, 6, 10, 15, 20, 26, 32, 39, 26, 9,
      11, 13, 15, 18, 22, 26, 30, 35, 51,
      /* batch 2 */
      11, 15, 4, 7, 8, 10, 10, 11, 10, 10, 8, 12, -6, 15, 14, 14, 12, 11, 8, 6,
      27, 32, 46, 54, 61, 70, 78, 88, 97,
      /* batch 3 */
      17, 21, 14, 17, 18, 20, 20, 21, 20, 20, 18, -7, 13, 14, 13, 13, 11, 10, 7,
      5, 26, 31, 37, 56, 63, 72, 80, 90, 99};
  // Only allow 1 element difference for the rounding result.
  CompareRoundingResults<int16_t>(4 * 29, expected_output.data(),
                                  batch_output.data(), 1, 1);
}

TEST(uKernels, VectorBatchVectorCwiseProductAccumulateFloat) {
  constexpr int kVectorSize = 29;
  constexpr int kBatchSize = 4;
  static float input[kVectorSize] = {
      1.1f,   2.2f,   3.3f,   4.4f,   5.5f,   6.6f,   7.7f,   8.8f,
      9.9f,   10.10f, 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f,
      17.17f, 18.18f, 19.19f, 20.20f, 21.21f, 22.22f, 23.23f, 24.24f,
      25.25f, 26.26f, 27.27f, 28.28f, 0.0f};
  std::vector<float> output = {
      /* batch 0 */
      1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.10f, 11.11f,
      12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.20f,
      21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f,
      /* batch 1 */
      -1.1f, -2.2f, -3.3f, -4.4f, -5.5f, -6.6f, -7.7f, -8.8f, -9.9f, -10.10f,
      -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f,
      -19.19f, -20.20f, -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f,
      -27.27f, -28.28f, 0.0f,
      /* batch 2 */
      1.1f, -2.2f, 3.3f, -4.4f, 5.5f, -6.6f, 7.7f, -8.8f, 9.9f, -10.10f, 11.11f,
      -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f,
      -20.20f, 21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f,
      -28.28f, 0.0f,
      /* batch 3 */
      -1.1f, 2.2f, -3.3f, 4.4f, -5.5f, 6.6f, -7.7f, 8.8f, -9.9f, 10.10f,
      -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f,
      -19.19f, 20.20f, -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f,
      -27.27f, 28.28f, 0.0f};
  VectorBatchVectorCwiseProductAccumulate(input, kVectorSize, output.data(),
                                          kBatchSize, output.data());

  // Expect output = input * output + output.
  const std::vector<float> expected_output = {
      /* batch 0 */
      2.31f, 7.04f, 14.19f, 23.76f, 35.75f, 50.16f, 66.99f, 86.24f, 107.91f,
      112.11f, 134.5421f, 159.0144f, 185.5269f, 214.0796f, 244.6725f, 277.3056f,
      311.9789f, 348.6924f, 387.4461f, 428.24f, 471.0741f, 515.9484f, 562.8629f,
      611.8176f, 662.8125f, 715.8476f, 770.9229f, 828.0384f, 0.0f,
      /* batch 1 */
      -2.31f, -7.04f, -14.19f, -23.76f, -35.75f, -50.16f, -66.99f, -86.24f,
      -107.91f, -112.11f, -134.5421f, -159.0144f, -185.5269f, -214.0796f,
      -244.6725f, -277.3056f, -311.9789f, -348.6924f, -387.4461f, -428.24f,
      -471.0741f, -515.9484f, -562.8629f, -611.8176f, -662.8125f, -715.8476f,
      -770.9229f, -828.0384f, 0.0f,
      /* batch 2 */
      2.31f, -7.04f, 14.19f, -23.76f, 35.75f, -50.16f, 66.99f, -86.24f, 107.91f,
      -112.11f, 134.5421f, -159.0144f, 185.5269f, -214.0796f, 244.6725f,
      -277.3056f, 311.9789f, -348.6924f, 387.4461f, -428.24f, 471.0741f,
      -515.9484f, 562.8629f, -611.8176f, 662.8125f, -715.8476f, 770.9229f,
      -828.0384f, 0.0f,
      /* batch 3 */
      -2.31f, 7.04f, -14.19f, 23.76f, -35.75f, 50.16f, -66.99f, 86.24f,
      -107.91f, 112.11f, -134.5421f, 159.0144f, -185.5269f, 214.0796f,
      -244.6725f, 277.3056f, -311.9789f, 348.6924f, -387.4461f, 428.24f,
      -471.0741f, 515.9484f, -562.8629f, 611.8176f, -662.8125f, 715.8476f,
      -770.9229f, 828.0384f, 0.0f};
  EXPECT_THAT(output, testing::ElementsAreArray(
                          ArrayFloatNear(expected_output, 6.5e-5f)));
}

TEST(uKernels, VectorBatchVectorCwiseProductNoAccumulate) {
  constexpr int kVectorSize = 29;
  constexpr int kBatchSize = 4;
  static float input[kVectorSize] = {
      1.1,   2.2,   3.3,   4.4,   5.5,   6.6,   7.7,   8.8,   9.9,   10.1,
      11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
      21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
  std::vector<float> output = {
      /* batch 0 */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
      24.24, 25.25, 26.26, 27.27, 28.28, 0,
      /* batch 1 */
      -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
      -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
      /* batch 2 */
      1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
      13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
      23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
      /* batch 3 */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
      -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
  VectorBatchVectorCwiseProduct(input, kVectorSize, output.data(), kBatchSize,
                                output.data());

  // Expect output = input * output (no accumulation).
  const std::vector<float> expected_output = {
      /* batch 0 */
      1.210000, 4.840000, 10.889999, 19.360001, 30.250000, 43.559998, 59.289997,
      77.440002, 98.009995, 102.010010, 123.432091, 146.894394, 172.396896,
      199.939606, 229.522491, 261.145599, 294.808899, 330.512421, 368.256134,
      408.040039, 449.864075, 493.728363, 539.632874, 587.577576, 637.562500,
      689.587585, 743.652954, 799.758423, 0.000000,
      /* batch 1 */
      -1.210000, -4.840000, -10.889999, -19.360001, -30.250000, -43.559998,
      -59.289997, -77.440002, -98.009995, -102.010010, -123.432091, -146.894394,
      -172.396896, -199.939606, -229.522491, -261.145599, -294.808899,
      -330.512421, -368.256134, -408.040039, -449.864075, -493.728363,
      -539.632874, -587.577576, -637.562500, -689.587585, -743.652954,
      -799.758423, 0.000000,
      /* batch 2 */
      1.210000, -4.840000, 10.889999, -19.360001, 30.250000, -43.559998,
      59.289997, -77.440002, 98.009995, -102.010010, 123.432091, -146.894394,
      172.396896, -199.939606, 229.522491, -261.145599, 294.808899, -330.512421,
      368.256134, -408.040039, 449.864075, -493.728363, 539.632874, -587.577576,
      637.562500, -689.587585, 743.652954, -799.758423, 0.000000,
      /* batch 3 */
      -1.210000, 4.840000, -10.889999, 19.360001, -30.250000, 43.559998,
      -59.289997, 77.440002, -98.009995, 102.010010, -123.432091, 146.894394,
      -172.396896, 199.939606, -229.522491, 261.145599, -294.808899, 330.512421,
      -368.256134, 408.040039, -449.864075, 493.728363, -539.632874, 587.577576,
      -637.562500, 689.587585, -743.652954, 799.758423, 0.000000};
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

TEST(uKernels, BatchVectorBatchVectorDotProductTest) {
  constexpr int kVectorSize = 5;
  constexpr int kBatch = 2;
  static float input1[kVectorSize * kBatch] = {0.0,  -0.5, 1.0,  -1.5, 2.0,
                                               -2.5, 3.0,  -3.5, 4.0,  -4.5};
  static float input2[kVectorSize * kBatch] = {0.1,  -0.1, 0.1,  -0.1, 0.1,
                                               -0.1, 0.1,  -0.1, 0.1,  -0.1};
  std::vector<float> output(kBatch);
  BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch,
                                   output.data());
  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({0.5, 1.75})));
}

TEST(uKernels, BatchVectorBatchVectorDotProductIntegerTest) {
  constexpr int kVectorSize = 5;
  constexpr int kBatch = 2;
  static int16_t input1[kVectorSize * kBatch] = {0,   5,  10,  -15, 20,
                                                 -25, 30, -35, 40,  -45};
  static int16_t input2[kVectorSize * kBatch] = {1,  -1, 1,  -1, 1,
                                                 -1, 1,  -1, 1,  1};
  std::vector<int32_t> output(kBatch);
  BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch,
                                   output.data());
  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({40, 85})));
}

TEST(uKernels, ReductionSumVectorTest) {
  constexpr int kInputVectorSize = 10;
  constexpr int kOutputVectorSize1 = 5;
  constexpr int kReductionSize1 = 2;
  static float input[kInputVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
                                          0.0, -0.5, 1.0, 1.0,  2.0};
  std::vector<float> result1(kOutputVectorSize1);
  ReductionSumVector(input, result1.data(), kOutputVectorSize1,
                     kReductionSize1);
  EXPECT_THAT(result1,
              ElementsAreArray(ArrayFloatNear({-0.5, -0.5, 2.0, 0.5, 3.0})));

  constexpr int kOutputVectorSize2 = 2;
  constexpr int kReductionSize2 = 5;
  std::vector<float> result2(kOutputVectorSize2);
  ReductionSumVector(input, result2.data(), kOutputVectorSize2,
                     kReductionSize2);
  EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
}

TEST(uKernels, ReductionSumVectorIntegerTest) {
  constexpr int kInputVectorSize = 10;
  constexpr int kOutputVectorSize1 = 5;
  constexpr int kReductionSize1 = 2;
  static int32_t input[kInputVectorSize] = {1, 2, 1, 5, -3, 2, 1, 2, 5, 10};
  std::vector<int32_t> result1(kOutputVectorSize1);
  ReductionSumVector(input, result1.data(), kOutputVectorSize1,
                     kReductionSize1);
  EXPECT_THAT(result1, testing::ElementsAreArray({3, 6, -1, 3, 15}));
}

void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
                          const int8_t* recurrent, int8_t recurrent_zp,
                          int32_t input_effective_scale_a,
                          int32_t input_effective_scale_b,
                          int32_t recurrent_effective_scale_a,
                          int32_t recurrent_effective_scale_b, int32_t n_batch,
                          int32_t n_cell, int16_t* output);
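// TwoGateSaturatingAdd (declared above, defined in the portable tensor utils)
// is expected to compute, per element,
//   y = rescale(input[i] - input_zp) + rescale(recurrent[i] - recurrent_zp),
// where each rescale applies its (multiplier, shift) pair in Q31 fixed point
// and y is saturated to the int16 range before being written to output.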

TEST(uKernels, TwoGateSaturateAddTest) {
  const std::vector<int8_t> input1 = {1, 2, 3, 4, 55, 66, 77};
  const std::vector<int8_t> input2 = {100, 2, 3, 4, 55, 66, 77};
  const int32_t input1_zp = 10;
  const int32_t input2_zp = -5;
  const int32_t multiplier1 = 1347771520;
  const int32_t shift1 = -7;
  const int32_t multiplier2 = 1047577121;
  const int32_t shift2 = -6;
  std::vector<int16_t> output(7);

  TwoGateSaturatingAdd(input1.data(), input1_zp, input2.data(), input2_zp,
                       multiplier1, shift1, multiplier2, shift2, 1, 7,
                       output.data());

  const std::vector<int16_t> expected_output = {1, 0, 0, 0, 0, 1, 1};
  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}

namespace {
// Parameterized test: mean, difference, tolerance.
// Input is constructed as [mean-2*diff, mean-diff, mean+diff, mean+2*diff]
class MeanStddevNormalizationTest
    : public testing::TestWithParam<std::tuple<float, float, float>> {};
}  // namespace

TEST_P(MeanStddevNormalizationTest, SeparateBatches) {
  const float mean = std::get<0>(GetParam());
  const float diff = std::get<1>(GetParam());
  const float tolerance = std::get<2>(GetParam());

  constexpr int kVectorSize = 4;
  const float input[kVectorSize] = {mean - 2 * diff, mean - diff, mean + diff,
                                    mean + 2 * diff};
  float output[kVectorSize];
  MeanStddevNormalization(input, output, kVectorSize, 1);
  std::vector<float> expected_output;
  if (diff == 0.0f) {
    expected_output.assign({0.0f, 0.0f, 0.0f, 0.0f});
  } else {
    const float ksqrt16 = std::sqrt(1.6f);
    const float ksqrt04 = std::sqrt(0.4f);
    expected_output.assign({-ksqrt16, -ksqrt04, ksqrt04, ksqrt16});
  }
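  // For inputs {m - 2d, m - d, m + d, m + 2d} the mean is m and the
  // population variance is (4 + 1 + 1 + 4) * d^2 / 4 = 2.5 * d^2, so the
  // normalized values are ±2/√2.5 = ±√1.6 and ±1/√2.5 = ±√0.4,
  // independent of m and d.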
  EXPECT_THAT(output, testing::ElementsAreArray(
                          ArrayFloatNear(expected_output, tolerance)));
}

INSTANTIATE_TEST_SUITE_P(
    uKernels, MeanStddevNormalizationTest,
    testing::Values(
        std::make_tuple(0.0f, 0.0f, 0.0f),         // zero mean, zero variance
        std::make_tuple(0.0f, 0.01f, 2.53e-5f),    // zero mean, small variance
        std::make_tuple(0.0f, 100.0f, 1.20e-7f),   // zero mean, large variance
        std::make_tuple(0.01f, 0.0f, 0.0f),        // small mean, zero variance
        std::make_tuple(0.01f, 0.01f, 2.53e-5f),   // small mean, small variance
        std::make_tuple(0.01f, 100.0f, 1.20e-7f),  // small mean, large variance
        std::make_tuple(100.0f, 0.0f, 0.0f),       // large mean, zero variance
        std::make_tuple(100.0f, 0.01f, 1.81e-4f),  // large mean, small variance
        std::make_tuple(100.0f, 100.0f, 1.20e-7f)  // large mean, large variance
        ));

TEST(uKernels, MeanStddevNormalizationAllBatches) {
  constexpr int kVectorSize = 4;
  constexpr int kBatchSize = 9;

  // Non-zero input.
  static float input[kVectorSize * kBatchSize] = {
      0.0f,     0.0f,    0.0f,    0.0f,     // zero mean, zero variance
      -0.02f,   -0.01f,  0.01f,   0.02f,    // zero mean, small variance
      -200.0f,  -100.0f, 100.0f,  200.0f,   // zero mean, large variance
      0.01f,    0.01f,   0.01f,   0.01f,    // small mean, zero variance
      -0.01f,   0.0f,    0.02f,   0.03f,    // small mean, small variance
      -199.99f, -99.99f, 100.01f, 200.01f,  // small mean, large variance
      100.0f,   100.0f,  100.0f,  100.0f,   // large mean, zero variance
      99.98f,   99.99f,  100.01f, 100.02f,  // large mean, small variance
      -100.0f,  0.0f,    200.0f,  300.0f,   // large mean, large variance
  };
  float output[kVectorSize * kBatchSize];
  MeanStddevNormalization(input, output, kVectorSize, kBatchSize);
  const float ksqrt16 = std::sqrt(1.6f);
  const float ksqrt04 = std::sqrt(0.4f);
  const std::vector<float> expected_output = {
      0.0f,     0.0f,     0.0f,    0.0f,     // zero mean, zero variance
      -ksqrt16, -ksqrt04, ksqrt04, ksqrt16,  // zero mean, small variance
      -ksqrt16, -ksqrt04, ksqrt04, ksqrt16,  // zero mean, large variance
      0.0f,     0.0f,     0.0f,    0.0f,     // small mean, zero variance
      -ksqrt16, -ksqrt04, ksqrt04, ksqrt16,  // small mean, small variance
      -ksqrt16, -ksqrt04, ksqrt04, ksqrt16,  // small mean, large variance
      0.0f,     0.0f,     0.0f,    0.0f,     // large mean, zero variance
      -ksqrt16, -ksqrt04, ksqrt04, ksqrt16,  // large mean, small variance
      -ksqrt16, -ksqrt04, ksqrt04, ksqrt16,  // large mean, large variance
  };
  EXPECT_THAT(output, testing::ElementsAreArray(
                          ArrayFloatNear(expected_output, 1.81e-4f)));
}

TEST(uKernels, MeanStddevNormalizationLargeVector) {
  const float mean = 100.0f;
  const float diff = 1.0f;
  // A large vector size that is not a round multiple of any SIMD vector size.
  // Note that the size is odd.
  constexpr int kVectorSize = 16 * 16 + 16 + 1;

  float input[kVectorSize];
  // First input is mean.
  input[0] = mean;
  // Rest is alternating between mean + diff and mean - diff.
  for (int i = 1; i < kVectorSize - 1; i += 2) {
    input[i + 0] = mean + diff;
    input[i + 1] = mean - diff;
  }
  float output[kVectorSize];
  MeanStddevNormalization(input, output, kVectorSize, 1);

  float expected_output[kVectorSize];
  // First output should be 0.
  expected_output[0] = 0.0;
  // Rest should be alternating between ±√(N/(N-1)).
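  // (N - 1 of the N inputs deviate from the mean by ±diff, so the population
  // variance is ((N - 1) / N) * diff², and each normalized deviation has
  // magnitude diff / √(((N - 1) / N) * diff²) = √(N / (N - 1)).)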
  const float expected_elem = std::sqrt(static_cast<double>(kVectorSize) /
                                        static_cast<double>(kVectorSize - 1));
  for (int i = 1; i < kVectorSize - 1; i += 2) {
    expected_output[i + 0] = +expected_elem;
    expected_output[i + 1] = -expected_elem;
  }
  EXPECT_THAT(output, testing::Pointwise(testing::FloatEq(), expected_output));
}

}  // namespace tensor_utils
}  // namespace tflite

#ifdef DOTPROD_BENCHMARKS

// Compile with --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" and
// --copt="-DDOTPROD_BENCHMARKS"
// Run with --benchmarks=all
void BM_DotprodBatchOneMultiply(benchmark::State& state) {
  const int rows = state.range(0);
  const int cols = state.range(1);
  const int batch = state.range(2);
  const int copies = state.range(3);

  // For some benchmarks we make multiple matrix copies. This allows us to
  // measure the performance differences of being entirely in cache vs.
  // out of cache.
  std::vector<tflite::tensor_utils::MatrixVectorData> datas;
  for (int i = 0; i < copies; i++) {
    datas.push_back(
        tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch));
  }

  int copy = 0;
  for (auto _ : state) {
    copy = (copy + 1) % datas.size();
    auto& data = datas[copy];
    for (int i = 0; i < batch; i++) {
      tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          data.matrix.data(), data.rows, data.cols,
          data.vectors.data() + (data.cols * i), data.scale_factors.data(), 1,
          &data.results[0], 1);
      testing::DoNotOptimize(data.results[2]);
    }
  }
}
BENCHMARK(BM_DotprodBatchOneMultiply)
    ->Args({16, 16, 1, 1})
    ->Args({16, 16, 4, 1})
    ->Args({32, 32, 1, 1})
    ->Args({32, 32, 4, 1})
    ->Args({64, 64, 1, 1})
    ->Args({64, 64, 4, 1})
    ->Args({128, 128, 1, 1})
    ->Args({128, 128, 4, 1})
    ->Args({992, 992, 1, 1})
    ->Args({992, 992, 8, 1})
    ->Args({1024, 1024, 1, 1})
    ->Args({1024, 1024, 1, 8})
    ->Args({1024, 1024, 4, 1})
    ->Args({1024, 1024, 4, 8})
    ->Args({1024, 1024, 8, 1})
    ->Args({640, 2048, 1, 1})
    ->Args({640, 2048, 4, 1})
    ->Args({640, 2048, 8, 1})
    ->Args({640, 2048, 8, 8})
    ->Args({2048, 2048, 1, 1})
    ->Args({2048, 2048, 1, 8})
    ->Args({2048, 2048, 8, 1});

void BM_DotprodBatchFourMultiply(benchmark::State& state) {
  const int rows = state.range(0);
  const int cols = state.range(1);
  const int batch = state.range(2);
  const int copies = state.range(3);

  // For some benchmarks we make multiple matrix copies. This allows us to
  // measure the performance differences of being entirely in cache vs.
  // out of cache.
  std::vector<tflite::tensor_utils::MatrixVectorData> datas;
  for (int i = 0; i < copies; i++) {
    datas.push_back(
        tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch));
  }

  int copy = 0;
  for (auto _ : state) {
    copy = (copy + 1) % datas.size();
    auto& data = datas[copy];
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        data.matrix.data(), data.rows, data.cols, data.vectors.data(),
        data.scale_factors.data(), data.batch, &data.results[0], 1);
    testing::DoNotOptimize(data.results[2]);
  }
}
BENCHMARK(BM_DotprodBatchFourMultiply)
    ->Args({16, 16, 4, 1})
    ->Args({32, 32, 4, 1})
    ->Args({64, 64, 4, 1})
    ->Args({64, 256, 64, 1})
    ->Args({64, 256, 256, 1})
    ->Args({64, 256, 1024, 1})
    ->Args({64, 256, 12544, 1})
    ->Args({128, 128, 2, 1})
    ->Args({128, 128, 3, 1})
    ->Args({128, 128, 4, 1})
    ->Args({128, 128, 5, 1})
    ->Args({640, 640, 4, 1})
    ->Args({992, 992, 8, 1})
    ->Args({1024, 1024, 2, 1})
    ->Args({1024, 1024, 3, 1})
    ->Args({1024, 1024, 4, 1})
    ->Args({1024, 1024, 5, 1})
    ->Args({1024, 1024, 8, 1})
    ->Args({1024, 1024, 8, 8})
    ->Args({1024, 1024, 256, 1})
    ->Args({640, 2048, 2, 1})
    ->Args({640, 2048, 3, 1})
    ->Args({640, 2048, 4, 1})
    ->Args({640, 2048, 4, 8})
    ->Args({640, 2048, 8, 1})
    ->Args({2048, 2048, 3, 1})
    ->Args({2048, 2048, 4, 1})
    ->Args({2048, 2048, 4, 8})
    ->Args({2048, 2048, 5, 1})
    ->Args({2048, 2048, 8, 1});

void BM_DotprodSparseMultiply(benchmark::State& state) {
  const int rows = state.range(0);
  const int cols = state.range(1);
  const int batch = state.range(2);

  const int copies = state.range(3);

  // For some benchmarks we make multiple matrix copies. This allows us to
  // measure the performance differences of being entirely in cache vs.
  // out of cache.
  std::vector<tflite::tensor_utils::MatrixVectorData> datas;
  for (int i = 0; i < copies; i++) {
    datas.push_back(
        tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch));
  }

  int copy = 0;
  for (auto _ : state) {
    copy = (copy + 1) % datas.size();
    auto& data = datas[copy];
    tflite::tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
        data.sparse_matrix.data(), data.ledger.data(), data.rows, data.cols,
        data.vectors.data(), data.scale_factors.data(), data.batch,
        &data.results[0]);
    testing::DoNotOptimize(data.results[2]);
  }
}
BENCHMARK(BM_DotprodSparseMultiply)
    ->Args({128, 128, 1, 1})
    ->Args({128, 128, 4, 1})
    ->Args({640, 640, 4, 1})
    ->Args({992, 992, 8, 1})
    ->Args({1024, 1024, 1, 1})
    ->Args({1024, 1024, 4, 1})
    ->Args({1024, 1024, 8, 1})
    ->Args({640, 2048, 1, 1})
    ->Args({640, 2048, 4, 1})
    ->Args({640, 2048, 8, 1})
    ->Args({2048, 2048, 1, 1})
    ->Args({2048, 2048, 8, 1});

#endif  // DOTPROD_BENCHMARKS