1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
16
17 #include <math.h>
18
19 #include <gmock/gmock.h>
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/kernels/cpu_backend_context.h"
22 #include "tensorflow/lite/kernels/internal/common.h"
23 #include "tensorflow/lite/kernels/internal/quantization_util.h"
24 #include "tensorflow/lite/kernels/test_util.h"
25
26 #ifdef DOTPROD_BENCHMARKS
27 #include "testing/base/public/benchmark.h"
28 #endif // DOTPROD_BENCHMARKS
29
30 namespace tflite {
31 namespace tensor_utils {
32
// Normally we would require bit-for-bit exact results. Unfortunately, a bug
// in the Intel arm_neon_sse.h translation header that we use for x86 tests
// causes 1-bit inaccuracy in the vqrdmulh_n_s32 intrinsic, which leads to
// off-by-one errors. So we have to live with a few off-by-one errors for now,
// while still ensuring that no more than a small minority of values are wrong.
// This utility compares the rounding results for integer outputs.
40 template <typename T>
void CompareRoundingResults(int flat_size, const T* expected_result,
                            const T* real_result, int max_element_tolerance = 1,
                            int max_total_tolerance = 5) {
44 int max_diff = 0;
45 int64_t total_diff = 0;
46 for (int i = 0; i < flat_size; i++) {
47 int diff = static_cast<int>(std::abs(expected_result[i] - real_result[i]));
48 total_diff += diff;
49 max_diff = std::max(max_diff, diff);
50 }
51
52 EXPECT_LE(max_diff, max_element_tolerance);
53 EXPECT_LE(total_diff, max_total_tolerance);
54 }
55
TEST(uKernels, FloorLog2Test) {
57 for (int i = 1; i < 257; ++i) {
58 EXPECT_EQ(::tflite::FloorLog2(i),
59 static_cast<int>(std::floor(std::log2(i))));
60 }
61 }
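// Note: for reference, FloorLog2(x) is expected to equal the index of the
// highest set bit of x, e.g. FloorLog2(1) == 0, FloorLog2(255) == 7 and
// FloorLog2(256) == 8, which is what the std::log2-based oracle above checks.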
62
TEST(uKernels, VectorScalarMultiply) {
64 constexpr int kVectorSize = 29;
65 static int8_t input[kVectorSize];
66 for (int i = 0; i < 29; ++i) {
67 input[i] = static_cast<int8_t>(i - 14);
68 }
69 const float scale = 0.1f;
70 std::vector<float> output(kVectorSize, 0.0f);
71 VectorScalarMultiply(input, kVectorSize, scale, output.data());
72 EXPECT_THAT(output,
73 ElementsAreArray(ArrayFloatNear(
74 {-1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5,
75 -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5,
76 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4})));
77 }
78
79 #define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
80
// Test if a float array is full of zero values.
TEST(uKernels, IsZeroFloatTest) {
83 // Single NEON vector (= 4 floats)
84 {
85 const float four_zeros[4] = {0, 0, 0, 0};
86 EXPECT_TRUE(IsZeroVector(four_zeros, ARRAY_SIZE(four_zeros)));
87 }
88 {
89 const float four_nonzeros[4] = {1, 2, 3, 4};
90 EXPECT_FALSE(IsZeroVector(four_nonzeros, ARRAY_SIZE(four_nonzeros)));
91 }
92 // Multiple NEON vectors
93 {
94 const float eight_zeros[8] = {0, 0, 0, 0, 0, 0, 0, 0};
95 EXPECT_TRUE(IsZeroVector(eight_zeros, ARRAY_SIZE(eight_zeros)));
96 }
97 {
98 const float eight_nonzeros[8] = {1, 2, 3, 4, 5, 6, 7, 8};
99 EXPECT_FALSE(IsZeroVector(eight_nonzeros, ARRAY_SIZE(eight_nonzeros)));
100 }
101 {
102 const float multiple_four_mixed1[8] = {0, 0, 0, 0, 5, 6, 7, 8};
103 EXPECT_FALSE(
104 IsZeroVector(multiple_four_mixed1, ARRAY_SIZE(multiple_four_mixed1)));
105 }
106 {
107 const float multiple_four_mixed2[8] = {1, 2, 3, 4, 0, 0, 0, 0};
108 EXPECT_FALSE(
109 IsZeroVector(multiple_four_mixed2, ARRAY_SIZE(multiple_four_mixed2)));
110 }
111 // less than one NEON vector
112 {
113 const float three_zeros[3] = {0, 0, 0};
114 EXPECT_TRUE(IsZeroVector(three_zeros, ARRAY_SIZE(three_zeros)));
115 }
116 {
117 const float three_nonzeros[3] = {1, 2, 3};
118 EXPECT_FALSE(IsZeroVector(three_nonzeros, ARRAY_SIZE(three_nonzeros)));
119 }
120 {
121 const float three_mixed[3] = {1, 0, 3};
122 EXPECT_FALSE(IsZeroVector(three_mixed, ARRAY_SIZE(three_mixed)));
123 }
124 // Postamble after NEON vectors
125 {
126 const float seven_zeros[7] = {0, 0, 0, 0, 0, 0, 0};
127 EXPECT_TRUE(IsZeroVector(seven_zeros, ARRAY_SIZE(seven_zeros)));
128 }
129 {
130 const float seven_nonzeros[7] = {1, 2, 3, 4, 5, 6, 7};
131 EXPECT_FALSE(IsZeroVector(seven_nonzeros, ARRAY_SIZE(seven_nonzeros)));
132 }
133 {
134 const float nonzeros_after_zeros[7] = {0, 0, 0, 0, 5, 6, 7};
135 EXPECT_FALSE(
136 IsZeroVector(nonzeros_after_zeros, ARRAY_SIZE(nonzeros_after_zeros)));
137 }
138 }
139
// Test if an int8 array is full of zero values.
TEST(uKernels, IsZeroInt8Test) {
142 // Single NEON vector (= 16x int8_t)
143 {
144 const int8_t sixteen_zeros[16] = {0, 0, 0, 0, 0, 0, 0, 0,
145 0, 0, 0, 0, 0, 0, 0, 0};
146 EXPECT_TRUE(IsZeroVector(sixteen_zeros, ARRAY_SIZE(sixteen_zeros)));
147 }
148 {
149 const int8_t sixteen_nonzeros[16] = {1, 2, 3, 4, 5, 6, 7, 8,
150 9, 10, 11, 12, 13, 14, 15, 16};
151 EXPECT_FALSE(IsZeroVector(sixteen_nonzeros, ARRAY_SIZE(sixteen_nonzeros)));
152 }
153 // Multiple NEON vectors
154 {
155 const int8_t thritytwo_zeros[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
156 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
157 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
158 EXPECT_TRUE(IsZeroVector(thritytwo_zeros, ARRAY_SIZE(thritytwo_zeros)));
159 }
160 {
161 const int8_t thritytwo_nonzeros[32] = {
162 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
163 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
164 EXPECT_FALSE(
165 IsZeroVector(thritytwo_nonzeros, ARRAY_SIZE(thritytwo_nonzeros)));
166 }
167 {
168 const int8_t thritytwo_mixed1[32] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
169 12, 13, 14, 15, 16, 0, 0, 0, 0, 0, 0,
170 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
171 EXPECT_FALSE(IsZeroVector(thritytwo_mixed1, ARRAY_SIZE(thritytwo_mixed1)));
172 }
173 {
174 const int8_t thritytwo_mixed2[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
175 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6,
176 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
177 EXPECT_FALSE(IsZeroVector(thritytwo_mixed2, ARRAY_SIZE(thritytwo_mixed2)));
178 }
179 // less than one NEON vector
180 {
181 const int8_t fifteen_zeros[15] = {0, 0, 0, 0, 0, 0, 0, 0,
182 0, 0, 0, 0, 0, 0, 0};
183 EXPECT_TRUE(IsZeroVector(fifteen_zeros, ARRAY_SIZE(fifteen_zeros)));
184 }
185 {
186 const int8_t fifteen_nonzeros[15] = {1, 2, 3, 4, 5, 6, 7, 8,
187 9, 10, 11, 12, 13, 14, 15};
188 EXPECT_FALSE(IsZeroVector(fifteen_nonzeros, ARRAY_SIZE(fifteen_nonzeros)));
189 }
190 {
191 const int8_t fifteen_mixed[15] = {1, 0, 3, 0, 5, 0, 7, 0,
192 9, 0, 11, 0, 13, 0, 15};
193 EXPECT_FALSE(IsZeroVector(fifteen_mixed, ARRAY_SIZE(fifteen_mixed)));
194 }
195 // Postamble after NEON vectors
196 {
197 const int8_t seventeen_zeros[17] = {0, 0, 0, 0, 0, 0, 0, 0,
198 0, 0, 0, 0, 0, 0, 0};
199 EXPECT_TRUE(IsZeroVector(seventeen_zeros, ARRAY_SIZE(seventeen_zeros)));
200 }
201 {
202 const int8_t seventeen_nonzeros[17] = {1, 2, 3, 4, 5, 6, 7, 8, 9,
203 10, 11, 12, 13, 14, 15, 16, 17};
204 EXPECT_FALSE(
205 IsZeroVector(seventeen_nonzeros, ARRAY_SIZE(seventeen_nonzeros)));
206 }
207 {
208 const int8_t nonzeros_after_zeros[17] = {0, 0, 0, 0, 0, 0, 0, 0,
209 0, 0, 0, 0, 0, 0, 17};
210 EXPECT_FALSE(
211 IsZeroVector(nonzeros_after_zeros, ARRAY_SIZE(nonzeros_after_zeros)));
212 }
213 }
214
215 #undef ARRAY_SIZE
216
TEST(uKernels, SymmetricQuantizeFloatsTest) {
218 constexpr int kVectorSize = 9;
219 static float input[kVectorSize] = {-640, -635.0, -630, 10.0, 2.0,
220 -5.0, -10.0, 0.0, 1000.0};
221
222 int8_t output[kVectorSize];
223 float min, max, scaling_factor;
224 SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
225 &scaling_factor);
226
227 EXPECT_EQ(min, -640);
228 EXPECT_EQ(max, 1000);
  // EQ won't work due to floating-point rounding.
230 EXPECT_NEAR(scaling_factor, 1000 / 127.0, 1e-6);
231 EXPECT_THAT(output,
232 testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127}));
233 }
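// Note: as a rough sanity check, symmetric quantization maps each value v to
// round(v / scale) with scale = max(|min|, |max|) / 127. Here scale is
// 1000 / 127, so e.g. -640 quantizes to round(-640 * 127 / 1000) = -81 and
// 1000 quantizes to 127, matching the expectations above.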
234
TEST(uKernels, SymmetricQuantizeFloatsAllZerosTest) {
236 constexpr int kVectorSize = 9;
237 static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
238
239 int8_t output[kVectorSize];
240 float min, max, scaling_factor;
241 SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
242 &scaling_factor);
243
244 EXPECT_EQ(min, 0);
245 EXPECT_EQ(max, 0);
246 EXPECT_EQ(scaling_factor, 1);
247 EXPECT_THAT(output, testing::ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0}));
248 }
249
TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) {
251 constexpr int kVectorSize = 9;
252 static float input[kVectorSize] = {-1e-5, 3e-5, -7e-6, -9e-5, 1e-6,
253 4e-5, 9e-6, 2e-4, 0};
254
255 int8_t output[kVectorSize];
256 float min, max, scaling_factor;
257 SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
258 &scaling_factor);
259
260 EXPECT_NEAR(min, -9e-05, 1e-6);
261 EXPECT_NEAR(max, 0.0002, 1e-6);
262 EXPECT_NEAR(scaling_factor, 1.57e-6, 1e-6);
263 EXPECT_THAT(output,
264 testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0}));
265 }
266
TEST(uKernels, AsymmetricQuantizeFloatsTest) {
268 constexpr int kVectorSize = 9;
269 static float input[kVectorSize] = {-640, -635.0, -630, 10.0, 2.0,
270 -5.0, -10.0, 0.0, 1000.0};
271 int8_t output[kVectorSize];
272 double min = -640.0;
273 double max = 1000.0;
274 QuantizationParams quantization_params =
275 ChooseQuantizationParams<int8_t>(min, max);
276 float scale = quantization_params.scale;
277 int32_t offset = quantization_params.zero_point;
278 float test_scale;
279 int32_t test_offset;
280 AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
281 &test_offset);
  // EQ won't work due to floating-point rounding.
283 EXPECT_NEAR(test_scale, scale, 1e-6);
284 EXPECT_EQ(test_offset, offset);
285 EXPECT_THAT(output, testing::ElementsAreArray(
286 {-128, -127, -126, -26, -28, -29, -30, -28, 127}));
287 }
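// Note: roughly speaking, asymmetric int8 quantization uses
// scale = (max - min) / 255 and a zero_point chosen so that 0.0 maps to an
// integer in [-128, 127]; each value v then maps to
// round(v / scale) + zero_point, clamped to [-128, 127]. With min = -640 and
// max = 1000, scale is about 6.43 and zero_point is -28, which is why the
// input 0.0 above quantizes to -28 and 1000.0 saturates to 127.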
288
TEST(uKernels, AsymmetricQuantizeFloatsAllZerosTest) {
290 constexpr int kVectorSize = 9;
291 static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
292 int8_t output[kVectorSize];
293 float test_scale;
294 int32_t test_offset;
295 AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
296 &test_offset);
297 EXPECT_EQ(test_scale, 1);
298 EXPECT_EQ(test_offset, 0);
299 EXPECT_THAT(output, testing::ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0}));
300 }
301
TEST(uKernels, AsymmetricQuantizeFloatsZeroRangeTest) {
303 constexpr int kVectorSize = 9;
304 static float input[kVectorSize] = {2000, 2000, 2000, 2000, 2000,
305 2000, 2000, 2000, 2000};
306 int8_t output[kVectorSize];
307 double min = 0;
308 double max = 2000;
309 QuantizationParams quantization_params =
310 ChooseQuantizationParams<int8_t>(min, max);
311 int32_t offset = quantization_params.zero_point;
312 float scale = quantization_params.scale;
313 float test_scale;
314 int32_t test_offset;
315 AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
316 &test_offset);
317 EXPECT_NEAR(test_scale, scale, 1e-6);
318 EXPECT_EQ(test_offset, offset);
319 EXPECT_THAT(output, testing::ElementsAreArray(
320 {127, 127, 127, 127, 127, 127, 127, 127, 127}));
321 }
322
TEST(uKernels, AsymmetricQuantizeFloatsAllAlmostZeroTest) {
324 constexpr int kVectorSize = 9;
325 static float input[kVectorSize] = {-1e-5, 3e-5, -7e-6, -9e-5, 1e-6,
326 4e-5, 9e-6, 2e-4, 0};
327 int8_t output[kVectorSize];
328 double min = -9e-05;
329 double max = 0.0002;
330 QuantizationParams quantization_params =
331 ChooseQuantizationParams<int8_t>(min, max);
332 int32_t offset = quantization_params.zero_point;
333 float scale = quantization_params.scale;
334 float test_scale;
335 int32_t test_offset;
336 AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
337 &test_offset);
338 EXPECT_NEAR(test_scale, scale, 1e-6);
339 EXPECT_EQ(test_offset, offset);
340 EXPECT_THAT(output, testing::ElementsAreArray(
341 {-58, -23, -55, -128, -48, -14, -41, 127, -49}));
342 }
343
TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) {
345 constexpr int kRow = 3;
346 constexpr int kCol = 4;
347 constexpr int kBatch = 2;
348 static float matrix[kRow * kCol] = {1.0, 2.0, 3.0, 4.0, //
349 -1.0, -2.0, -3.0, -4.0, //
350 1.0, -2.0, 3.0, -4.0};
351 static float vector[kCol * kBatch] = {1.0, -1.0, 1.0, -1.0, //
352 2.0, -2.0, 2.0, -2.0};
353 std::vector<float> output(kRow * kBatch);
354 std::fill(output.begin(), output.end(), 3.0);
355 MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch,
356 output.data());
357 EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({1., 5., 13., //
358 -1., 7., 23.})));
359 }
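// For reference, each result accumulates on top of the initial 3.0: e.g. the
// first output is 3 + (1 * 1 + 2 * -1 + 3 * 1 + 4 * -1) = 3 - 2 = 1, the
// first entry of the expectation above.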
360
361 // Quantized matmul with 2 * 30 input and 9 * 30 matrix.
TEST(uKernels, QuantMatrixBatchVectorMultiplyAccumulate8x8_16Test) {
363 CpuBackendContext context;
364 const std::vector<int8_t> input = {
365 4, -41, 5, -41, 22, 17, -30, 24, 13, -47, 18, 9, -11, -30, 16,
366 -47, 12, 36, -20, 27, -3, 0, -51, -31, 3, -8, -38, 43, 23, 12,
367 11, -23, -26, 23, 14, -9, -44, 22, 21, -30, 3, -47, -26, -21, -24,
368 -44, 34, -11, -23, -28, 26, -38, 19, 35, 9, 23, 6, -42, -25, 28,
369 };
370 const std::vector<int32_t> input_zeropoint_times_weights = {
371 -620, -170, -395, 715, -1220, -1080, 1130, -260, -470,
372 };
373 const std::vector<int8_t> input_to_gate_weights = {
374 -10, -4, -8, 16, 4, -16, -1, 11, 1, 2, -25, 19, 7, 9, 2,
375 -24, -2, 10, -7, 7, -5, -2, 3, 4, 3, -4, -7, -11, -13, -18,
376 11, 10, 12, -9, 17, -15, -5, 20, -6, -11, 2, -6, -18, 15, 4,
377 4, -9, -2, -3, -9, -13, 17, -21, 5, 3, -12, 0, -4, 9, -5,
378 10, -2, 8, 1, -10, -6, 1, -9, 10, 11, -1, -5, 4, -7, -4,
379 -4, 4, 12, -7, -5, -9, -19, 6, -4, 12, -17, -22, 0, 9, -4,
380 -5, 5, -8, 8, 3, 15, -18, -18, 5, 3, -12, 5, -10, 7, 7,
381 -9, 17, 2, -11, -25, 3, 19, -6, 7, 1, 7, 5, -3, 11, 3,
382 0, -8, 8, -2, -2, -12, 14, -5, 7, 8, 16, 20, -16, -5, -5,
383 1, -10, -6, 14, 10, -12, 10, -6, 5, 0, 3, 8, -9, -13, -2,
384 4, 4, -16, -17, -9, 16, -5, 14, -9, -5, -12, 0, 17, 6, -1,
385 16, -20, 1, -11, -1, -10, -21, 13, 4, -12, -7, 0, -14, -6, 3,
386 -4, 6, -18, -3, -1, 14, -8, -6, -15, 5, 12, -3, -10, 4, 6,
387 -5, -20, 0, 3, -3, -7, 1, 2, -10, 7, -3, 6, 1, -12, 6,
388 4, -12, 2, 6, -20, 0, 5, 23, 15, 14, 9, 8, 20, -2, 9,
389 -8, -8, -7, -4, -8, -9, 7, -12, -2, 2, 1, -14, 31, 4, -14,
390 3, 10, -18, -17, -1, 18, 1, 12, 0, 7, -3, -5, 8, -9, 18,
391 17, 7, -15, 3, 20, 4, -8, 16, 6, -3, -3, 9, -4, -6, 4,
392 };
393 const int32_t multiplier = 2080364544;
394 const int32_t shift = -2;
395
396 std::vector<int32_t> scratch(2 * 9, 0);
397 std::vector<int16_t> output = {10, 2, 33, 4, 5, 6, 65, 4, 3,
398 52, 1, 2, 8, -1, -2, 11, 17, -18};
399 MatrixBatchVectorMultiplyAccumulate(
400 input.data(), input_zeropoint_times_weights.data(),
401 input_to_gate_weights.data(), multiplier, shift,
402 /*n_batch=*/2, /*n_input=*/30, /*n_output=*/9, /*output_zp=*/0,
403 scratch.data(), output.data(), &context);
404 const std::vector<int16_t> expected_output = {
405 -210, 331, 153, 139, -570, -657, 258, 515, -495,
406 91, -243, -73, 603, -744, -269, 169, -748, -174,
407 };
408
409 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
410 }
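// Note: the (multiplier, shift) pair above is the usual TFLite representation
// of the effective output scale: the accumulated int32 sums are rescaled by
// approximately multiplier / 2^31 * 2^shift (a negative shift meaning a right
// shift) before being combined with output_zp and saturated to the output
// type.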
411
TEST(uKernels, HybridMatrixBatchVectorMultiplyAccumulate8x8_16Test) {
413 CpuBackendContext context;
414 const std::vector<int8_t> input = {
415 4, -41, 5, -41, 22, 17, -30, 24, 13, -47, 18, 9, -11, -30, 16,
416 1, -47, 12, 36, -20, 27, -3, 0, -51, -31, 3, -8, -38, 43, 23,
417 12, 1, 11, -23, -26, 23, 14, -9, -44, 22, 21, -30, 3, -47, -26,
418 -21, -24, 1, -44, 34, -11, -23, -28, 26, -38, 19, 35, 9, 23, 6,
419 -42, -25, 28, 1, 4, -41, 5, -41, 22, 17, -30, 24, 13, -47, 18,
420 9, -11, -30, 16, 1, -47, 12, 36, -20, 27, -3, 0, -51, -31, 3,
421 -8, -38, 43, 23, 12, 1, 11, -23, -26, 23, 14, -9, -44, 22, 21,
422 -30, 3, -47, -26, -21, -24, 1, -44, 34, -11, -23, -28, 26, -38, 19,
423 35, 9, 23, 6, -42, -25, 28, 1,
424 };
425 const std::vector<int32_t> input_offsets = {1, 1, 1, 1};
426
427 const std::vector<float> scaling_factors = {
428 1.0,
429 1.0,
430 1.0,
431 1.0,
432 };
433
434 const std::vector<int8_t> input_to_gate_weights = {
435 -10, -4, -8, 16, 4, -16, -1, 11, 1, 2, -25, 19, 7, 9, 2,
436 1, -24, -2, 10, -7, 7, -5, -2, 3, 4, 3, -4, -7, -11, -13,
437 -18, 2, 11, 10, 12, -9, 17, -15, -5, 20, -6, -11, 2, -6, -18,
438 15, 4, 3, 4, -9, -2, -3, -9, -13, 17, -21, 5, 3, -12, 0,
439 -4, 9, -5, 4, 10, -2, 8, 1, -10, -6, 1, -9, 10, 11, -1,
440 -5, 4, -7, -4, 5, -4, 4, 12, -7, -5, -9, -19, 6, -4, 12,
441 -17, -22, 0, 9, -4, 6, -5, 5, -8, 8, 3, 15, -18, -18, 5,
442 3, -12, 5, -10, 7, 7, 7, -9, 17, 2, -11, -25, 3, 19, -6,
443 7, 1, 7, 5, -3, 11, 3, 8, 0, -8, 8, -2, -2, -12, 14,
444 -5, 7, 8, 16, 20, -16, -5, -5, 9, 1, -10, -6, 14, 10, -12,
445 10, -6, 5, 0, 3, 8, -9, -13, -2, 10, 4, 4, -16, -17, -9,
446 16, -5, 14, -9, -5, -12, 0, 17, 6, -1, 11, 16, -20, 1, -11,
447 -1, -10, -21, 13, 4, -12, -7, 0, -14, -6, 3, 12, -4, 6, -18,
448 -3, -1, 14, -8, -6, -15, 5, 12, -3, -10, 4, 6, 13, -5, -20,
449 0, 3, -3, -7, 1, 2, -10, 7, -3, 6, 1, -12, 6, 14, -5,
450 -20, 0, 3, -3, -7, 1, 2, -10, 7, -3, 6, 1, -12, 6, 15,
451 -5, -20, 0, 3, -3, -7, 1, 2, -10, 7, -3, 6, 1, -12, 6,
452 16,
453 };
454
455 std::vector<int32_t> scratch(5 * 8, 0);
456 std::vector<float> output(4 * 8, 0);
457 int32_t* row_sums = scratch.data() + 8 * 4;
458 bool compute_row_sums = true;
459 MatrixBatchVectorMultiplyAccumulate(
460 input_to_gate_weights.data(), /*m_rows=*/8, /*m_cols=*/32, input.data(),
461 scaling_factors.data(), /*n_batch*/ 4, output.data(), nullptr,
462 input_offsets.data(), scratch.data(), row_sums, &compute_row_sums,
463 &context);
464
465 const std::vector<float_t> expected_output = {
466 -228, 1548, 937, -166, -1164, -1578, -278, 303, 839, -820, 132,
467 1733, -1858, 58, -425, -587, -228, 1548, 937, -166, -1164, -1578,
468 -278, 303, 839, -820, 132, 1733, -1858, 58, -425, -587,
469 };
470
471 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
472 EXPECT_THAT(compute_row_sums, false);
473
474 std::vector<float> output2(4 * 8, 0);
475 MatrixBatchVectorMultiplyAccumulate(
476 input_to_gate_weights.data(), /*m_rows=*/8, /*m_cols=*/32, input.data(),
477 scaling_factors.data(), /*n_batch*/ 4, output2.data(), nullptr,
478 input_offsets.data(), scratch.data(), row_sums, &compute_row_sums,
479 &context);
480
481 EXPECT_THAT(output2, testing::ElementsAreArray(expected_output));
482
483 // Run with a large batch size to trigger the CpuBackendGemm path on any
484 // device.
485 constexpr int kBatchMultiplier = 8;
486 std::vector<int8_t> input_big_batch(input.size() * kBatchMultiplier);
487 std::vector<float> scaling_factors_big_batch(scaling_factors.size() *
488 kBatchMultiplier);
489 std::vector<int32_t> scratch_big_batch(scratch.size() * kBatchMultiplier);
490 std::vector<int32_t> input_offsets_big_batch(input_offsets.size() *
491 kBatchMultiplier);
492 for (int i = 0; i < kBatchMultiplier; i++) {
493 std::copy(input.begin(), input.end(),
494 input_big_batch.begin() + i * input.size());
495 std::copy(scaling_factors.begin(), scaling_factors.end(),
496 scaling_factors_big_batch.begin() + i * scaling_factors.size());
497 std::copy(input_offsets.begin(), input_offsets.end(),
498 input_offsets_big_batch.begin() + i * input_offsets.size());
499 }
500 std::vector<float> output_big_batch(output.size() * kBatchMultiplier, 0);
501 MatrixBatchVectorMultiplyAccumulate(
502 input_to_gate_weights.data(), /*m_rows=*/8, /*m_cols=*/32,
503 input_big_batch.data(), scaling_factors_big_batch.data(),
504 /*n_batch*/ 4 * kBatchMultiplier, output_big_batch.data(), nullptr,
505 input_offsets_big_batch.data(), scratch_big_batch.data(), row_sums,
506 &compute_row_sums, &context);
507 for (int i = 0; i < kBatchMultiplier; i++) {
508 std::vector<float> output_per_batch(
509 output_big_batch.begin() + i * output.size(),
510 output_big_batch.begin() + (i + 1) * output.size());
511 EXPECT_THAT(output_per_batch, testing::ElementsAreArray(expected_output));
512 }
513 }
514
// Quantized matmul with 2 * 30 input and 9 * 30 matrix.
TEST(uKernels, QuantMatrixBatchVectorMultiplyAccumulate8x8_8Test) {
517 CpuBackendContext context;
518 const std::vector<int8_t> input = {
519 4, -41, 5, -41, 22, 17, -30, 24, 13, -47, 18, 9, -11, -30, 16,
520 -47, 12, 36, -20, 27, -3, 0, -51, -31, 3, -8, -38, 43, 23, 12,
521 11, -23, -26, 23, 14, -9, -44, 22, 21, -30, 3, -47, -26, -21, -24,
522 -44, 34, -11, -23, -28, 26, -38, 19, 35, 9, 23, 6, -42, -25, 28,
523 };
524 const std::vector<int32_t> input_zeropoint_times_weights = {
525 0, 0, 0, 0, 0, 0, 0, 0, 0,
526 };
527 const std::vector<int8_t> input_to_gate_weights = {
528 13, -7, -20, -22, 8, -46, 9, -2, -18, -42, 40, 28, -7, 24, 34,
529 -7, -24, -24, 19, 14, -19, -6, -2, -3, 5, -36, -13, 6, -27, 36,
530 -23, 0, 20, -37, -23, 9, 17, -41, 33, -15, -18, -42, -41, -34, -16,
531 -6, 12, -14, -15, -20, -14, 21, -3, -1, -26, 54, 51, 35, -14, 9,
532 -2, 13, -6, 39, 34, -21, 39, -51, 19, -44, 52, 0, -2, -38, -35,
533 -33, 4, -22, -37, 27, -23, 3, -10, 5, 32, 6, 1, -35, 24, -19,
534 46, 43, -55, 5, 38, -14, 32, -43, -44, -17, -13, -28, 56, 28, -42,
535 4, 10, -7, 25, -15, -9, -25, -14, -15, 6, -10, -22, 40, -72, 18,
536 -6, -18, -2, 37, -13, -10, 11, -9, 32, -28, 19, -2, 4, -31, 50,
537 -15, 23, -34, -9, 41, -6, -34, 17, 2, 24, -15, 21, -17, -8, -20,
538 1, -63, 19, -40, 12, -5, 5, -6, 1, 19, -9, -23, 5, -34, 11,
539 26, 21, 54, 34, -43, -29, 1, 16, 31, -56, -28, 57, -15, -23, 37,
540 -17, -3, -6, 29, 18, 77, 17, -20, -14, -19, 8, -24, -7, -45, -3,
541 0, -25, -8, 6, 9, 3, -15, 51, 4, -15, -19, -16, -14, -47, -52,
542 25, 9, 58, 26, -9, -27, 49, -6, -21, 21, 18, 12, -9, -9, 14,
543 31, -26, -19, -50, 17, 35, 11, -10, 22, -16, -43, -2, 26, 55, -20,
544 -7, 21, 33, -20, 26, -15, -22, 30, 27, 3, -34, 26, 12, -1, 19,
545 26, -25, 10, 30, 30, -14, -23, -23, -35, -16, 26, -41, 11, 1, 21,
546 };
547 const int32_t multiplier = 1347771520;
548 const int32_t shift = -7;
549 const int32_t output_zp = -11;
550
551 std::vector<int8_t> output = {1, 2, 3, 4, 5, 6, 5, 4, 3,
552 2, 1, 2, 8, -1, -2, 11, 17, 18};
553 std::vector<int32_t> scratch(2 * 9, 0);
554 MatrixBatchVectorMultiplyAccumulate(
555 input.data(), input_zeropoint_times_weights.data(),
556 input_to_gate_weights.data(), multiplier, shift,
557 /*n_batch=*/2, /*n_input=*/30, /*n_output=*/9, output_zp, scratch.data(),
558 output.data(), &context);
559 const std::vector<int8_t> expected_output = {
560 5, -9, -2, -30, -5, -11, -22, -18, 18,
561 -19, 2, 11, -5, 9, -2, 10, -38, -22,
562 };
563
564 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
565 }
566
// Quantized matmul with 2 * 30 input and 9 * 30 matrix with zero point.
TEST(uKernels, QuantMatrixBatchVectorMultiply8x8_8WithZPTest) {
569 const int32_t input_zp = 3;
570 const std::vector<int8_t> input = {
571 4, -41, 5, -41, 22, 17, -30, 24, 13, -47, 18, 9, -11, -30, 16,
572 -47, 12, 36, -20, 27, -3, 0, -51, -31, 3, -8, -38, 43, 23, 12,
573 11, -23, -26, 23, 14, -9, -44, 22, 21, -30, 3, -47, -26, -21, -24,
574 -44, 34, -11, -23, -28, 26, -38, 19, 35, 9, 23, 6, -42, -25, 28,
575 };
576 const std::vector<int8_t> input_to_gate_weights = {
577 13, -7, -20, -22, 8, -46, 9, -2, -18, -42, 40, 28, -7, 24, 34,
578 -7, -24, -24, 19, 14, -19, -6, -2, -3, 5, -36, -13, 6, -27, 36,
579 -23, 0, 20, -37, -23, 9, 17, -41, 33, -15, -18, -42, -41, -34, -16,
580 -6, 12, -14, -15, -20, -14, 21, -3, -1, -26, 54, 51, 35, -14, 9,
581 -2, 13, -6, 39, 34, -21, 39, -51, 19, -44, 52, 0, -2, -38, -35,
582 -33, 4, -22, -37, 27, -23, 3, -10, 5, 32, 6, 1, -35, 24, -19,
583 46, 43, -55, 5, 38, -14, 32, -43, -44, -17, -13, -28, 56, 28, -42,
584 4, 10, -7, 25, -15, -9, -25, -14, -15, 6, -10, -22, 40, -72, 18,
585 -6, -18, -2, 37, -13, -10, 11, -9, 32, -28, 19, -2, 4, -31, 50,
586 -15, 23, -34, -9, 41, -6, -34, 17, 2, 24, -15, 21, -17, -8, -20,
587 1, -63, 19, -40, 12, -5, 5, -6, 1, 19, -9, -23, 5, -34, 11,
588 26, 21, 54, 34, -43, -29, 1, 16, 31, -56, -28, 57, -15, -23, 37,
589 -17, -3, -6, 29, 18, 77, 17, -20, -14, -19, 8, -24, -7, -45, -3,
590 0, -25, -8, 6, 9, 3, -15, 51, 4, -15, -19, -16, -14, -47, -52,
591 25, 9, 58, 26, -9, -27, 49, -6, -21, 21, 18, 12, -9, -9, 14,
592 31, -26, -19, -50, 17, 35, 11, -10, 22, -16, -43, -2, 26, 55, -20,
593 -7, 21, 33, -20, 26, -15, -22, 30, 27, 3, -34, 26, 12, -1, 19,
594 26, -25, 10, 30, 30, -14, -23, -23, -35, -16, 26, -41, 11, 1, 21,
595 };
596 const int32_t multiplier = 1347771520;
597 const int32_t shift = -7;
598 const int32_t output_zp = -11;
599
600 std::vector<int8_t> output = {1, 2, 3, 4, 5, 6, 5, 4, 3,
601 2, 1, 2, 8, -1, -2, 11, 17, 18};
602
603 MatrixBatchVectorMultiply(
604 input.data(), input_zp, input_to_gate_weights.data(), multiplier, shift,
605 /*n_batch=*/2, /*n_input=*/30, /*n_cell=*/9, output.data(), output_zp);
606 const std::vector<int8_t> expected_output = {6, -9, -4, -32, -10, -17,
607 -25, -25, 14, -19, 3, 10,
608 -12, 10, 0, 1, -57, -41};
609
610 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
611 }
612
// Quantized matmul with 2 * 30 input and 9 * 30 matrix with zero point.
TEST(uKernels, QuantMatrixBatchVectorMultiply16x8_8WithZPTest) {
615 const std::vector<int16_t> input = {
616 400, -41, 5, -41, 22, 17, -30, 24, 130, -47, 18, 9, -11, -30, 16,
617 -47, 12, 36, -20, 27, -3, 0, -51, -31, 3, -8, -38, 43, 23, 12,
618 11, -23, -26, 23, 14, -9, -44, 22, 21, -30, 3, -47, -26, -21, -24,
619 -44, 34, -11, -23, -28, 26, -38, 19, 35, 9, 23, 6, -42, -25, 28,
620 };
621 const std::vector<int8_t> input_to_gate_weights = {
622 13, -7, -20, -22, 8, -46, 9, -2, -18, -42, 40, 28, -7, 24, 34,
623 -7, -24, -24, 19, 14, -19, -6, -2, -3, 5, -36, -13, 6, -27, 36,
624 -23, 0, 20, -37, -23, 9, 17, -41, 33, -15, -18, -42, -41, -34, -16,
625 -6, 12, -14, -15, -20, -14, 21, -3, -1, -26, 54, 51, 35, -14, 9,
626 -2, 13, -6, 39, 34, -21, 39, -51, 19, -44, 52, 0, -2, -38, -35,
627 -33, 4, -22, -37, 27, -23, 3, -10, 5, 32, 6, 1, -35, 24, -19,
628 46, 43, -55, 5, 38, -14, 32, -43, -44, -17, -13, -28, 56, 28, -42,
629 4, 10, -7, 25, -15, -9, -25, -14, -15, 6, -10, -22, 40, -72, 18,
630 -6, -18, -2, 37, -13, -10, 11, -9, 32, -28, 19, -2, 4, -31, 50,
631 -15, 23, -34, -9, 41, -6, -34, 17, 2, 24, -15, 21, -17, -8, -20,
632 1, -63, 19, -40, 12, -5, 5, -6, 1, 19, -9, -23, 5, -34, 11,
633 26, 21, 54, 34, -43, -29, 1, 16, 31, -56, -28, 57, -15, -23, 37,
634 -17, -3, -6, 29, 18, 77, 17, -20, -14, -19, 8, -24, -7, -45, -3,
635 0, -25, -8, 6, 9, 3, -15, 51, 4, -15, -19, -16, -14, -47, -52,
636 25, 9, 58, 26, -9, -27, 49, -6, -21, 21, 18, 12, -9, -9, 14,
637 31, -26, -19, -50, 17, 35, 11, -10, 22, -16, -43, -2, 26, 55, -20,
638 -7, 21, 33, -20, 26, -15, -22, 30, 27, 3, -34, 26, 12, -1, 19,
639 26, -25, 10, 30, 30, -14, -23, -23, -35, -16, 26, -41, 11, 1, 21,
640 };
641
642 const std::vector<int32_t> input_zeropoint_times_weights = {
643 0, 2, 3, 4, 5, 4, 3, 2, 10,
644 };
645 const int32_t multiplier = 1347771520;
646 const int32_t shift = -8;
647 const int32_t output_zp = -11;
648
649 std::vector<int8_t> output = {1, 2, 3, 4, 5, 6, 5, 4, 3,
650 2, 1, 2, 8, -1, -2, 11, 17, 18};
651
652 MatrixBatchVectorMultiply(
653 input.data(), input_to_gate_weights.data(), multiplier, shift,
654 input_zeropoint_times_weights.data(),
655 /*n_batch=*/2, /*n_hidden=*/30, /*n_output=*/9, output_zp, output.data());
656 const std::vector<int8_t> expected_output = {4, -24, -5, 10, -7, -13,
657 -39, 2, 3, -16, -5, -1,
658 -12, -1, -6, -6, -33, -25};
659
660 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
661 }
662
663 // Quantized matmul with 9 * 30 matrix.
TEST(uKernels, MatrixScalarMultiplyAccumulateTest) {
665 std::vector<int32_t> output = {
666 -620, -170, -395, 715, -1220, -1080, 1130, -260, -470,
667 };
668 const std::vector<int8_t> weight = {
669 -10, -4, -8, 16, 4, -16, -1, 11, 1, 2, -25, 19, 7, 9, 2,
670 -24, -2, 10, -7, 7, -5, -2, 3, 4, 3, -4, -7, -11, -13, -18,
671 11, 10, 12, -9, 17, -15, -5, 20, -6, -11, 2, -6, -18, 15, 4,
672 4, -9, -2, -3, -9, -13, 17, -21, 5, 3, -12, 0, -4, 9, -5,
673 10, -2, 8, 1, -10, -6, 1, -9, 10, 11, -1, -5, 4, -7, -4,
674 -4, 4, 12, -7, -5, -9, -19, 6, -4, 12, -17, -22, 0, 9, -4,
675 -5, 5, -8, 8, 3, 15, -18, -18, 5, 3, -12, 5, -10, 7, 7,
676 -9, 17, 2, -11, -25, 3, 19, -6, 7, 1, 7, 5, -3, 11, 3,
677 0, -8, 8, -2, -2, -12, 14, -5, 7, 8, 16, 20, -16, -5, -5,
678 1, -10, -6, 14, 10, -12, 10, -6, 5, 0, 3, 8, -9, -13, -2,
679 4, 4, -16, -17, -9, 16, -5, 14, -9, -5, -12, 0, 17, 6, -1,
680 16, -20, 1, -11, -1, -10, -21, 13, 4, -12, -7, 0, -14, -6, 3,
681 -4, 6, -18, -3, -1, 14, -8, -6, -15, 5, 12, -3, -10, 4, 6,
682 -5, -20, 0, 3, -3, -7, 1, 2, -10, 7, -3, 6, 1, -12, 6,
683 4, -12, 2, 6, -20, 0, 5, 23, 15, 14, 9, 8, 20, -2, 9,
684 -8, -8, -7, -4, -8, -9, 7, -12, -2, 2, 1, -14, 31, 4, -14,
685 3, 10, -18, -17, -1, 18, 1, 12, 0, 7, -3, -5, 8, -9, 18,
686 17, 7, -15, 3, 20, 4, -8, 16, 6, -3, -3, 9, -4, -6, 4,
687 };
688 MatrixScalarMultiplyAccumulate(weight.data(), 3, 9, 30, output.data());
689 const std::vector<int32_t> expected_output = {
690 -797, -227, -536, 739, -1187, -1314, 965, -140, -257,
691 };
692
693 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
694 }
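// For reference, this op is expected to add scalar * (sum of each row) to the
// corresponding output entry: the first row of the 9 x 30 weight matrix sums
// to -59, so output[0] goes from -620 to -620 + 3 * (-59) = -797, matching
// the first expected value above.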
695
696 // Quantized layer norm of n_batch = 2 and n_input = 15.
TEST(uKernels, QuantApplyLayerNormTest) {
698 const std::vector<int16_t> input = {
699 -310, 596, 34, -68, 475, 92, 672, -54, -913, -200,
700 -1194, -836, -620, -237, 991, 533, 721, -736, -8, -941,
701 -372, -1084, 591, 2557, -779, 175, 582, 956, -287, 944,
702 };
703 const std::vector<int16_t> layer_norm_weights = {
704 21849, 22882, 20626, 23854, 24779, 26354, 12980, 26231,
705 23716, 27271, 24937, 22647, 24715, 22854, 19646,
706 };
707 const std::vector<int32_t> bias_weight = {
708 -14175520, -13805465, -16027609, -13786809, -13321033,
709 -14399810, -15055368, -14536623, -14508746, -13784007,
710 -15206609, -15125830, -14996304, -14847597, -12814379,
711 };
712 const int32_t multiplier = 1895840000;
713 const int32_t shift = -13;
714 const int32_t limit = 1;
715
716 std::vector<int16_t> output(2 * 15, 0);
717 ApplyLayerNorm(input.data(), layer_norm_weights.data(), bias_weight.data(),
718 multiplier, shift, limit, 2, 15, output.data());
719 const std::vector<int16_t> expected_output = {
720 -9407, 5846, -4802, -5295, 4822, -2390, 930, -5283,
721 -20352, -7846, -26539, -18704, -15829, -8627, 10313, -2522,
722 -132, -16058, -8206, -19158, -13296, -14407, -1235, 20612,
723 -18591, -6738, -2274, 2602, -11622, 1565,
724 };
725 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
726 }
727
728 // Quantized layer norm of n_batch = 2 and n_input = 15.
TEST(uKernels, QuantApplyLayerNormFloatTest) {
730 const std::vector<int16_t> input = {
731 -310, 596, 34, -68, 475, 92, 672, -54, -913, -200,
732 -1194, -836, -620, -237, 991, 533, 721, -736, -8, -941,
733 -372, -1084, 591, 2557, -779, 175, 582, 956, -287, 944,
734 };
735 const std::vector<int16_t> layer_norm_weights = {
736 21849, 22882, 20626, 23854, 24779, 26354, 12980, 26231,
737 23716, 27271, 24937, 22647, 24715, 22854, 19646,
738 };
739 const std::vector<int32_t> bias_weight = {
740 -14175520, -13805465, -16027609, -13786809, -13321033,
741 -14399810, -15055368, -14536623, -14508746, -13784007,
742 -15206609, -15125830, -14996304, -14847597, -12814379,
743 };
744 const int32_t multiplier = 1895840000;
745 const int32_t shift = -13;
746
747 std::vector<int16_t> output(2 * 15, 0);
748 ApplyLayerNormFloat(input.data(), layer_norm_weights.data(), multiplier,
749 shift, bias_weight.data(), 2, 15, output.data());
750 const std::vector<int16_t> expected_output = {
751 -9408, 5844, -4803, -5297, 4826, -2392, 927, -5286,
752 -20353, -7851, -26534, -18701, -15830, -8623, 10312, -2524,
753 -136, -16053, -8206, -19160, -13299, -14407, -1233, 20617,
754 -18594, -6736, -2272, 2597, -11620, 1566};
755
756 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
757 }
758
759 // Quantized tanh with Q0.15 input and Q0.15 output.
TEST(uKernels, QuantTanh0Test) {
761 const std::vector<int16_t> input = {
762 -145, 899, -176, -35, 264, 289, 8, 27, -37, -1310,
763 -120, 127, -16, 106, 370, -583, -299, 93, -548, 548,
764 653, -29, -53, 1058, -52, -164, -149, -635, 201, -1297,
765 -145, 899, -176, -35, 264, 289, 8, 27, -37, -1310,
766 -120, 127, -16, 106, 370, -583, -299, 93, -548, 548,
767 653, -29, -53, 1058, -52, -164, -149, -635, 201, -1297,
768 };
769 std::vector<int16_t> output(4 * 15, 0);
770 ApplyTanh(0, input.data(), 4, 15, output.data());
771 const std::vector<int16_t> expected_output = {
772 -136, 904, -176, -40, 260, 292, 8, 28, -44, -1304,
773 -120, 120, -24, 112, 376, -576, -308, 88, -544, 544,
774 652, -32, -60, 1056, -56, -156, -144, -636, 192, -1300,
775 -136, 904, -176, -40, 260, 292, 8, 28, -44, -1304,
776 -120, 120, -24, 112, 376, -576, -308, 88, -544, 544,
777 652, -32, -60, 1056, -56, -156, -144, -636, 192, -1300,
778 };
779 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
780 }
781
782 // Quantized tanh with Q3.12 input and Q0.15 output.
TEST(uKernels, QuantTanh3Test) {
784 const std::vector<int16_t> input = {
785 -145, 899, -176, -35, 264, 289, 8, 27, -37, -1310,
786 -120, 127, -16, 106, 370, -583, -299, 93, -548, 548,
787 653, -29, -53, 1058, -52, -164, -149, -635, 201, -1297,
788 -145, 899, -176, -35, 264, 289, 8, 27, -37, -1310,
789 -120, 127, -16, 106, 370, -583, -299, 93, -548, 548,
790 653, -29, -53, 1058, -52, -164, -149, -635, 201, -1297,
791 };
792 std::vector<int16_t> output(4 * 15, 0);
793 ApplyTanh(3, input.data(), 4, 15, output.data());
794 const std::vector<int16_t> expected_output = {
795 -1156, 7076, -1412, -276, 2104, 2308, 64, 220, -288, -10132,
796 -964, 1016, -120, 844, 2944, -4640, -2392, 736, -4352, 4352,
797 5180, -232, -428, 8276, -412, -1308, -1196, -5044, 1612, -10044,
798 -1156, 7076, -1412, -276, 2104, 2308, 64, 220, -288, -10132,
799 -964, 1016, -120, 844, 2944, -4640, -2392, 736, -4352, 4352,
800 5180, -232, -428, 8276, -412, -1308, -1196, -5044, 1612, -10044,
801 };
802 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
803 }
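// Note on the fixed-point formats used by these tanh tests: Qm.n means the
// raw int16 value divided by 2^n, so a Q3.12 input of 899 represents roughly
// 899 / 4096 = 0.219, and tanh(0.219) = 0.216, which in Q0.15 output is about
// 0.216 * 32768 = 7079 -- close to the 7076 expected above (the small
// difference comes from the fixed-point approximation).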
804
805 // Quantized tanh with float calculation.
TEST(uKernels, QuantTanhFloatTest) {
807 const std::vector<int16_t> input = {
808 -1, 0, 1, -35, 264, 289, 8, 27, -37, -1310,
809 -120, 127, -16, 106, 370, -583, -299, 93, -548, 548,
810 653, -29, -53, 1058, -52, -164, -149, -635, 201, -1297,
811 -145, 899, -176, -35, 264, 289, 8, 27, -37, -1310,
812 -120, 127, -16, 106, 370, -583, -299, 93, -548, 548,
813 653, -29, -53, 1058, -52, -164, -149, -635, 201, -1297,
814 };
815 std::vector<int16_t> output(4 * 15, 0);
816 ApplyTanhFloat(input.data(), 4, 15, -12, output.data());
817 const std::vector<int16_t> expected_output = {
818 -8, 0, 8, -279, 2109, 2308, 63, 215, -295, -10136,
819 -959, 1015, -127, 847, 2951, -4632, -2387, 743, -4358, 4358,
820 5180, -231, -423, 8280, -415, -1311, -1191, -5039, 1606, -10042,
821 -1159, 7078, -1407, -279, 2109, 2308, 63, 215, -295, -10136,
822 -959, 1015, -127, 847, 2951, -4632, -2387, 743, -4358, 4358,
823 5180, -231, -423, 8280, -415, -1311, -1191, -5039, 1606, -10042};
824
825 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
826 }
827
828 // Quantized tanh with Q4.11 input and Q0.15 output.
TEST(uKernels, QuantTanh4Test) {
830 const std::vector<int16_t> input = {
831 -5, 163, -31, -5, 54, 90, 1, 2, -4, -42, -8, 29, 0, 47, 150,
832 -26, -36, 9, -73, 25, 14, -2, -1, 29, -10, -12, -18, -29, 51, -92,
833 -5, 163, -31, -5, 54, 90, 1, 2, -4, -42, -8, 29, 0, 47, 150,
834 -26, -36, 9, -73, 25, 14, -2, -1, 29, -10, -12, -18, -29, 51, -92,
835 };
836 std::vector<int16_t> output(4 * 15, 0);
837 ApplyTanh(4, input.data(), 4, 15, output.data());
838 const std::vector<int16_t> expected_output = {
839 -76, 2596, -496, -76, 856, 1436, 24, 36, -64, -672,
840 -120, 456, 0, 752, 2400, -412, -576, 148, -1168, 400,
841 216, -36, -24, 456, -164, -192, -292, -456, 820, -1476,
842 -76, 2596, -496, -76, 856, 1436, 24, 36, -64, -672,
843 -120, 456, 0, 752, 2400, -412, -576, 148, -1168, 400,
844 216, -36, -24, 456, -164, -192, -292, -456, 820, -1476,
845 };
846 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
847 }
848
849 // Quantized sigmoid with Q3.12 input and Q0.15 output.
TEST(uKernels, QuantSigmoidTest) {
851 const std::vector<int16_t> input = {
852 -10500, 1398, -6963, -7404, 485, -5401, -1757, -7668, -19248,
853 -9692, -24249, -17923, -15840, -10026, 5249, -89, 1787, -16178,
854 -6691, -19524, -13439, -24048, -1123, 32767, -17267, -3378, 823,
855 11482, -11139, 7508, -10500, 1398, -6963, -7404, 485, -5401,
856 -1757, -7668, -19248, -9692, -24249, -17923, -15840, -10026, 5249,
857 -89, 1787, -16178, -6691, -19524, -13439, -24048, -1123, 32767,
858 -17267, -3378, 823, 11482, -11139, 7508,
859 };
860 std::vector<int16_t> output(4 * 15, 0);
861 ApplySigmoid(input.data(), 4, 15, output.data());
862 const std::vector<int16_t> expected_output = {
863 2339, 19152, 5063, 4617, 17350, 6917, 12921, 4371, 299, 2813,
864 89, 409, 673, 2605, 25646, 16207, 19904, 615, 5353, 273,
865 1187, 91, 14153, 32756, 475, 9983, 18026, 30898, 2023, 28246,
866 2339, 19152, 5063, 4617, 17350, 6917, 12921, 4371, 299, 2813,
867 89, 409, 673, 2605, 25646, 16207, 19904, 615, 5353, 273,
868 1187, 91, 14153, 32756, 475, 9983, 18026, 30898, 2023, 28246,
869 };
870 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
871 }
872
// Quantized sigmoid with Q3.12 input and Q0.15 output, using float
// calculation.
TEST(uKernels, QuantSigmoidFloatTest) {
875 const std::vector<int16_t> input = {
876 -10500, 1398, -6963, -7404, 485, -5401, -1757, -7668, -19248,
877 -9692, -24249, -17923, -15840, -10026, 5249, -89, 1787, -16178,
878 -6691, -19524, -13439, -24048, -1123, 32767, -17267, -3378, 823,
879 11482, -11139, 7508, -10500, 1398, -6963, -7404, 485, -5401,
880 -1757, -7668, -19248, -9692, -24249, -17923, -15840, -10026, 5249,
881 -89, 1787, -16178, -6691, -19524, -13439, -24048, -1123, 32767,
882 -17267, -3378, 823, 11482, -11139, 7508,
883 };
884 std::vector<int16_t> output(4 * 15, 0);
885 ApplySigmoidFloat(input.data(), 4, 15, output.data());
886 const std::vector<int16_t> expected_output = {
887 2343, 19153, 5061, 4617, 17352, 6915, 12922, 4368, 295, 2811,
888 87, 407, 671, 2608, 25647, 16206, 19902, 619, 5352, 276,
889 1187, 92, 14151, 32757, 476, 9986, 18024, 30895, 2026, 28249,
890 2343, 19153, 5061, 4617, 17352, 6915, 12922, 4368, 295, 2811,
891 87, 407, 671, 2608, 25647, 16206, 19902, 619, 5352, 276,
892 1187, 92, 14151, 32757, 476, 9986, 18024, 30895, 2026, 28249};
893
894 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
895 }
896
897 // Quantized Multiply with 16bit output and 15 bit shift.
TEST(uKernels, QuantMul16bitOut15ShiftTest) {
899 const std::vector<int16_t> input1 = {
900 2491, 32767, -32768, 32767, -32768, 32767, 32767, -32768, -32768, 2157,
901 4545, 14835, 1285, 29498, 26788, 2907, 7877, 6331, 8775, 3001,
902 1399, 4683, 1437, 1853, 12163, 4927, 7977, 3001, 16612, 4791,
903 };
904 const std::vector<int16_t> input2 = {
905 -1156, 32767, -32768, -32768, 32767, 2308, 64, 220, -288, -10132,
906 -964, 1016, -120, 844, 2944, -4640, -2392, 736, -4352, 4352,
907 5180, -232, -428, 8276, -412, -1308, -1196, -5044, 1612, -10044,
908 };
909 std::vector<int16_t> output(2 * 15, 0);
910 CwiseMul(input1.data(), input2.data(), 2, 15, 15, output.data());
911 const std::vector<int16_t> expected_output = {
912 -88, 32766, -32768, -32767, -32767, 2308, 64, -220, 288, -667,
913 -134, 460, -5, 760, 2407, -412, -575, 142, -1165, 399,
914 221, -33, -19, 468, -153, -197, -291, -462, 817, -1469,
915 };
916 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
917 }
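// For reference, with a 15-bit shift each output is roughly
// (input1[i] * input2[i]) / 2^15, rounded and saturated to int16: e.g.
// 2491 * -1156 = -2879596, and -2879596 / 32768 is about -87.9, matching the
// -88 expected above, while 32767 * 32767 / 32768 comes out to the 32766 in
// the second slot.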
918
919 // Quantized Multiply with 16bit output and 19 bit shift.
TEST(uKernels, QuantMul16bitOut19ShiftTest) {
921 const std::vector<int16_t> input1 = {
922 2491, 32767, -32768, 32767, -32768, 32767, 32767, -32768, -32768, 2157,
923 4545, 14835, 1285, 29498, 26788, 2907, 7877, 6331, 8775, 3001,
924 1399, 4683, 1437, 1853, 12163, 4927, 7977, 3001, 16612, 4791,
925 };
926 const std::vector<int16_t> input2 = {
927 -1156, 32767, -32768, -32768, 32767, 2308, 64, 220, -288, -10132,
928 -964, 1016, -120, 844, 2944, -4640, -2392, 736, -4352, 4352,
929 5180, -232, -428, 8276, -412, -1308, -1196, -5044, 1612, -10044,
930 };
931 std::vector<int16_t> output(2 * 15, 0);
932 CwiseMul(input1.data(), input2.data(), 2, 15, 19, output.data());
933 const std::vector<int16_t> expected_output = {
934 -5, 2048, 2048, -2048, -2048, 144, 4, -14, 18, -42,
935 -8, 29, 0, 47, 150, -26, -36, 9, -73, 25,
936 14, -2, -1, 29, -10, -12, -18, -29, 51, -92,
937 };
938 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
939 }
940
941 // Quantized Multiply with arbitrary scale.
TEST(uKernels, QuantMul8bitArbitraryScaleTest) {
943 // scale = 0.000028.
944 int multiplier = 1970324837;
945 int shift = -15;
946
947 const std::vector<int16_t> input1 = {
948 2491, 32767, -32768, 32767, -32768, 32767, 32767, -32768, -32768, 2157,
949 4545, 14835, 1285, 29498, 26788, 2907, 7877, 6331, 8775, 3001,
950 1399, 4683, 1437, 1853, 12163, 4927, 7977, 3001, 16612, 4791,
951 };
952 const std::vector<int16_t> input2 = {
953 -1156, 32767, -32768, -32768, 32767, 2308, 64, 220, -288, -10132,
954 -964, 1016, -120, 844, 2944, -4640, -2392, 736, -4352, 4352,
955 5180, -232, -428, 8276, -412, -1308, -1196, -5044, 1612, -10044,
956 };
957 std::vector<int8_t> output(2 * 15, 0);
958 CwiseMul(input1.data(), input2.data(), multiplier, shift, 2, 15, 3,
959 output.data());
960 const std::vector<int8_t> expected_output = {
961 -84, 127, 127, -128, -128, 127, 56, -128, 127, -128,
962 -126, 127, -7, 127, 127, -128, -128, 127, -128, 127,
963 127, -33, -20, 127, -128, -128, -128, -128, 127, -128,
964 };
965 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
966 }
967
968 // Quantized element wise Add with saturation.
TEST(uKernels, QuantAddTest) {
970 const std::vector<int16_t> input1 = {
971 2491, 32767, -32768, 32767, -32768, 32767, 32767, -32768, -32768, 20000,
972 -20000, 14835, 1285, 29498, 26788, 2907, 7877, 6331, 8775, 3001,
973 1399, 4683, 1437, 1853, 12163, 4927, 7977, 3001, 16612, 4791,
974 };
975 const std::vector<int16_t> input2 = {
976 -1156, 32767, -32768, -32768, 32767, 2308, 64, 220, -288, 20000,
977 -20000, 1016, -120, 844, 2944, -4640, -2392, 736, -4352, 4352,
978 5180, -232, -428, 8276, -412, -1308, -1196, -5044, 1612, -10044,
979 };
980 std::vector<int16_t> output(2 * 15, 0);
981 CwiseAdd(input1.data(), input2.data(), 2, 15, output.data());
982 const std::vector<int16_t> expected_output = {
983 1335, 32767, -32768, -1, -1, 32767, 32767, -32548, -32768, 32767,
984 -32768, 15851, 1165, 30342, 29732, -1733, 5485, 7067, 4423, 7353,
985 6579, 4451, 1009, 10129, 11751, 3619, 6781, -2043, 18224, -5253,
986 };
987 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
988 }
989
TEST(uKernels, ClipTest) {
991 constexpr int kVectorSize = 10;
992 constexpr float kAbsLimit = 2.0;
993 std::vector<float> input = {0.0, -0.5, 1.0, -1.5, 2.0,
994 -2.5, 3.0, -3.5, 4.0, -4.5};
995 CwiseClipping(input.data(), kVectorSize, kAbsLimit);
996 const std::vector<float> expected_output = {0.0, -0.5, 1.0, -1.5, 2.0,
997 -2.0, 2.0, -2.0, 2.0, -2.0};
998 EXPECT_THAT(input, testing::ElementsAreArray(expected_output));
999 }
1000
1001 // Quantized clipping for 16 bit.
TEST(uKernels, QuantClip16Test) {
1003 constexpr int kVectorSize = 30;
1004 constexpr int16_t kAbsLimit = 300;
1005 std::vector<int16_t> input = {
1006 -10500, 1, -2, -7404, 200, -5401, -1757, -7668,
1007 -19248, -9692, -24249, -17923, -15840, -10026, 5249, -89,
1008 1787, -200, -6691, -19524, -13439, -24048, -1123, 32767,
1009 -17267, -3378, 823, 11482, -11139, 7508,
1010 };
1011 CwiseClipping(input.data(), kVectorSize, kAbsLimit);
1012 const std::vector<int16_t> expected_output = {
1013 -300, 1, -2, -300, 200, -300, -300, -300, -300, -300,
1014 -300, -300, -300, -300, 300, -89, 300, -200, -300, -300,
1015 -300, -300, -300, 300, -300, -300, 300, 300, -300, 300,
1016 };
1017 EXPECT_THAT(input, testing::ElementsAreArray(expected_output));
1018 }
1019
1020 // Quantized clipping for 8 bit.
TEST(uKernels, QuantClip8Test) {
1022 constexpr int kVectorSize = 30;
1023 constexpr int8_t kAbsLimit = 32;
1024 std::vector<int8_t> input = {
1025 4, -11, -5, -34, -10, -17, -27, -22, 15, 127, -128, 1, 3, 56, 3,
1026 -21, 1, 9, -13, 10, 0, -1, -55, -40, 127, -128, 11, 4, 6, 32,
1027 };
1028 CwiseClipping(input.data(), kVectorSize, kAbsLimit);
1029 const std::vector<int8_t> expected_output = {
1030 4, -11, -5, -32, -10, -17, -27, -22, 15, 32, -32, 1, 3, 32, 3,
1031 -21, 1, 9, -13, 10, 0, -1, -32, -32, 32, -32, 11, 4, 6, 32,
1032 };
1033 EXPECT_THAT(input, testing::ElementsAreArray(expected_output));
1034 }
1035
1036 struct MatrixVectorData {
1037 // Contains dense parameters.
1038 std::vector<int8_t> matrix;
1039
1040 // Like matrix, but with about half of the parameters set to zero.
1041 // Use this to create golden output for sparse matrix tests.
1042 std::vector<int8_t> zeroed_matrix;
1043
1044 // zeroed_matrix described in sparse form.
1045 std::vector<int8_t> sparse_matrix;
1046 std::vector<uint8_t> ledger;
1047
1048 std::vector<int8_t> vectors;
1049 std::vector<float> scale_factors;
1050 std::vector<float> results;
1051
1052 // Per channel scale data.
1053 std::vector<float> per_channel_scales;
1054 std::vector<int32_t> input_offsets;
1055
1056 int rows;
1057 int cols;
1058 int batch;
1059 };
1060
MatrixVectorData SetupMatrixVectorData(int rows, int cols, int batch,
                                       bool negative = false,
                                       bool is_per_channel = false,
                                       bool init_to_one = false) {
1065 MatrixVectorData data;
1066 data.rows = rows;
1067 data.cols = cols;
1068 data.batch = batch;
1069
1070 for (int i = 0; i < rows * cols; i++) {
1071 int sign = 1;
1072 if ((i % 3) == 0 && negative) sign = -1;
1073 data.matrix.push_back(sign * (i % 70));
1074 }
1075 for (int i = 0; i < cols * batch; i++) {
1076 int sign = 1;
1077 if ((i % 5) == 0 && negative) sign = -1;
1078 data.vectors.push_back(sign * (i % 50));
1079 }
1080 data.scale_factors = {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8};
1081 data.results.resize(rows * batch, init_to_one ? 1 : 0);
1082
1083 data.zeroed_matrix = data.matrix;
1084
1085 // Make a sparsification ledger.
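  // The ledger layout assumed here is the one used by the sparse kernels in
  // this file: for every row it stores the number of retained 16-wide column
  // chunks, followed by the index of each retained chunk; only those chunks'
  // values are copied into sparse_matrix.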
1086 for (int i = 0; i < rows; i++) {
1087 int max_chunks = cols / 16;
1088 int selected_chunks = (max_chunks / 2);
1089 bool row_is_odd = (i % 2) > 0;
1090 bool max_chunks_is_odd = (max_chunks % 2) > 0;
1091
1092 data.ledger.push_back(selected_chunks);
1093 if (max_chunks_is_odd && row_is_odd) {
1094 selected_chunks++;
1095 }
1096
1097 // In odd rows, use odd chunk indexes.
1098 // In even rows, use even chunk indexes.
1099 for (int j = 0; j < max_chunks; j++) {
1100 const int chunk_start = i * cols + (j * 16);
1101 const int chunk_end = i * cols + (j * 16) + 16;
1102 if ((j % 2) == (i % 2)) {
1103 // Copy this chunk into the sparse matrix.
1104 data.ledger.push_back(j);
1105 for (int k = chunk_start; k < chunk_end; k++) {
1106 data.sparse_matrix.push_back(data.matrix[k]);
1107 }
1108 } else {
1109 // Zero this part out of zeroed_matrix.
1110 for (int k = chunk_start; k < chunk_end; k++) {
1111 data.zeroed_matrix[k] = 0;
1112 }
1113 }
1114 }
1115 }
1116
1117 if (is_per_channel) {
1118 for (int i = 0; i < rows; i++) {
1119 if (i % 2 == 0) {
1120 data.per_channel_scales.push_back(0.5);
1121 } else {
1122 data.per_channel_scales.push_back(1.0);
1123 }
1124 }
1125
1126 for (int i = 0; i < batch; i++) {
1127 for (int j = 0; j < cols; j++) {
1128 data.vectors[i * cols + j] += i;
1129 }
1130 data.input_offsets.push_back(i);
1131 }
1132 }
1133 return data;
1134 }
1135
std::vector<float> TestDotprodMatrixBatchVectorMultiply(
    int rows, int cols, int batch, bool negative = false,
    bool init_to_one = false) {
1139 MatrixVectorData data =
1140 SetupMatrixVectorData(rows, cols, batch, negative, false, init_to_one);
1141
1142 // All partial sums in this computation are small enough to fit in the
1143 // mantissa of a float, and the scale factors are all integers, so we expect
1144 // an exact result.
1145 MatrixBatchVectorMultiplyAccumulate(
1146 data.matrix.data(), rows, cols, data.vectors.data(),
1147 data.scale_factors.data(), batch, &data.results[0]);
1148 return data.results;
1149 }
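// For reference, the hybrid kernel computes
// result[batch * rows + row] += scale_factors[batch] * dot(matrix row, vector)
// for each row/batch pair. In the simplest case below (rows=4, cols=16,
// batch=1), matrix[i] == i and the vector is 0..15 with scale factor 1, so
// the first result is the sum of i^2 for i in [0, 16), i.e. 1240.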
1150
std::vector<float> TestSparseDotprodMatrixBatchVectorMultiply(
    int rows, int cols, int batch, bool negative = false) {
1153 MatrixVectorData data = SetupMatrixVectorData(rows, cols, batch, negative);
1154 SparseMatrixBatchVectorMultiplyAccumulate(
1155 data.sparse_matrix.data(), data.ledger.data(), rows, cols,
1156 data.vectors.data(), data.scale_factors.data(), batch, &data.results[0]);
1157 return data.results;
1158 }
1159
std::vector<float> TestPerChannelDotprodMatrixBatchVectorMultiply(
    int rows, int cols, int batch, bool negative = false,
    bool is_per_channel = true) {
1163 MatrixVectorData data =
1164 SetupMatrixVectorData(rows, cols, batch, negative, is_per_channel);
1165 std::vector<int32_t> scratch(rows * batch);
1166 std::vector<int32_t> row_sums(rows);
1167 bool compute_row_sums = true;
1168 CpuBackendContext context;
1169 MatrixBatchVectorMultiplyAccumulate(
1170 data.matrix.data(), rows, cols, data.vectors.data(),
1171 data.scale_factors.data(), batch, &data.results[0],
1172 data.per_channel_scales.data(), data.input_offsets.data(), scratch.data(),
1173 row_sums.data(), &compute_row_sums, &context);
1174 return data.results;
1175 }
1176
TEST(uKernels, DotprodMatrixBatchVectorMultiplyAccumulateTest) {
1178 ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 16, 1),
1179 testing::ElementsAre(1240, 3160, 5080, 7000));
1180
1181 ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 32, 2),
1182 testing::ElementsAre(10416, 26288, 8490, 23312, 18276, 70756,
1183 37416, 60916));
1184
1185 ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 32, 3),
1186 testing::ElementsAre(10416, 26288, 8490, 23312, 18276, 70756,
1187 37416, 60916, 52080, 142704, 55878, 125712));
1188
1189 ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(8, 1024, 3),
1190 testing::ElementsAreArray(
1191 {841094, 853168, 866642, 840286, 860760, 862754,
1192 843678, 872552, 1724476, 1769072, 1747588, 1738844,
1193 1758240, 1742916, 1761612, 1755808, 2506896, 2564262,
1194 2629188, 2515824, 2598390, 2569236, 2537352, 2645118}));
1195
1196 const bool kNegative = true;
1197 ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 64, 1, kNegative),
1198 testing::ElementsAre(13696, 6904, 7764, 11806));
1199 ASSERT_THAT(
1200 TestDotprodMatrixBatchVectorMultiply(4, 32, 2, kNegative),
1201 testing::ElementsAre(3436, 3522, 1590, 6972, 2516, 20520, 456, 10628));
1202
1203 // Initialize the results vector with 1s to verify that the code adds
1204 // to the results vector instead of zero-ing it first.
1205 const bool kInitToOne = true;
1206 ASSERT_THAT(
1207 TestDotprodMatrixBatchVectorMultiply(4, 32, 2, kNegative, kInitToOne),
1208 testing::ElementsAre(3437, 3523, 1591, 6973, 2517, 20521, 457, 10629));
1209 }
1210
TEST(uKernels, PerChannelDotprodMatrixBatchVectorMultiplyAccumulateTest) {
1212 ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 16, 1),
1213 testing::ElementsAre(1240 / 2, 3160, 5080 / 2, 7000));
1214
1215 ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 32, 2),
1216 testing::ElementsAre(10416 / 2, 26288, 8490 / 2, 23312, 18276 / 2,
1217 70756, 37416 / 2, 60916));
1218
1219 ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 32, 3),
1220 testing::ElementsAre(10416 / 2, 26288, 8490 / 2, 23312, 18276 / 2,
1221 70756, 37416 / 2, 60916, 52080 / 2, 142704,
1222 55878 / 2, 125712));
1223
1224 ASSERT_THAT(
1225 TestPerChannelDotprodMatrixBatchVectorMultiply(8, 1024, 3),
1226 testing::ElementsAreArray(
1227 {841094 / 2, 853168, 866642 / 2, 840286, 860760 / 2, 862754,
1228 843678 / 2, 872552, 1724476 / 2, 1769072, 1747588 / 2, 1738844,
1229 1758240 / 2, 1742916, 1761612 / 2, 1755808, 2506896 / 2, 2564262,
1230 2629188 / 2, 2515824, 2598390 / 2, 2569236, 2537352 / 2, 2645118}));
1231 }
1232
TEST(uKernels, DotprodMatrixBatchFourVectorMultiplyAccumulateDotprodTest) {
1234 ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 16, 4),
1235 testing::ElementsAreArray(
1236 {1240, 3160, 6320, 18352, 15240, 45576, 4200, 16232}));
1237 ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 64, 4),
1238 testing::ElementsAreArray({45794, 38948, 88536, 84252, 157626,
1239 165312, 209864, 246128}));
1240 ASSERT_THAT(
1241 TestDotprodMatrixBatchVectorMultiply(2, 64, 8),
1242 testing::ElementsAreArray({45794, 38948, 88536, 84252, 157626, 165312,
1243 209864, 246128, 219700, 195550, 279684, 278928,
1244 413616, 445662, 374896, 365952}));
1245
1246 ASSERT_THAT(
1247 TestDotprodMatrixBatchVectorMultiply(4, 64, 8),
1248 testing::ElementsAreArray(
1249 {45794, 38948, 34622, 32816, 88536, 84252, 85008, 90804,
1250 157626, 165312, 180558, 203364, 209864, 246128, 236472, 208896,
1251 219700, 195550, 184000, 185050, 279684, 278928, 293292, 322776,
1252 413616, 445662, 495348, 513674, 374896, 365952, 321168, 296544}));
1253
1254 ASSERT_THAT(
1255 TestDotprodMatrixBatchVectorMultiply(16, 1024, 4),
1256 testing::ElementsAreArray(
1257 {841094, 853168, 866642, 840286, 860760, 862754, 843678,
1258 872552, 837586, 851270, 877414, 834188, 863062, 857846,
1259 841780, 879054, 1724476, 1769072, 1747588, 1738844, 1758240,
1260 1742916, 1761612, 1755808, 1737684, 1750780, 1747356, 1754152,
1261 1748348, 1753324, 1743320, 1754316, 2506896, 2564262, 2629188,
1262 2515824, 2598390, 2569236, 2537352, 2645118, 2508444, 2571480,
1263 2610576, 2510442, 2618208, 2566584, 2544570, 2614536, 3458904,
1264 3502688, 3474792, 3505976, 3499360, 3488264, 3485848, 3512832,
1265 3500616, 3482520, 3489624, 3469008, 3495992, 3524376, 3465680,
1266 3526264}));
1267
1268 ASSERT_THAT(
1269 TestDotprodMatrixBatchVectorMultiply(4, 128, 4),
1270 testing::ElementsAreArray({87920, 80024, 92288, 103712, 228148, 224820,
1271 233812, 213124, 271284, 271788, 332772, 328236,
1272 419328, 431328, 411968, 417248}));
1273
1274 ASSERT_THAT(
1275 TestDotprodMatrixBatchVectorMultiply(4, 128, 8),
1276 testing::ElementsAreArray(
1277 {87920, 80024, 92288, 103712, 228148, 224820, 233812, 213124,
1278 271284, 271788, 332772, 328236, 419328, 431328, 411968, 417248,
1279 482680, 523840, 560800, 593560, 563940, 609924, 566868, 644772,
1280 743708, 857780, 818972, 823284, 708384, 695008, 730912, 872096}));
1281
1282 const bool kNegative = true;
1283 EXPECT_THAT(TestDotprodMatrixBatchVectorMultiply(1, 16, 1, kNegative),
1284 testing::ElementsAre(450));
1285 EXPECT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 64, 8, kNegative),
1286 testing::ElementsAreArray({13696, 6904, 9952, 12368, 22848, 61632,
1287 40424, 46776, 57630, 38670, 62976,
1288 49824, 39032, 71988, 60128, 148992}));
1289
1290 std::vector<float> results =
1291 TestDotprodMatrixBatchVectorMultiply(256, 1024, 8);
1292 int64_t sum = 0;
1293 for (int i = 0; i < results.size(); i++) {
1294 sum += static_cast<int64_t>(results[i]);
1295 }
1296 EXPECT_EQ(7980076336, sum);
1297 }
1298
1299 TEST(uKernels,
1300 PerChannelDotprodMatrixBatchFourVectorMultiplyAccumulateDotprodTest) {
1301 ASSERT_THAT(
1302 TestPerChannelDotprodMatrixBatchVectorMultiply(16, 1024, 4),
1303 testing::ElementsAreArray(
1304 {841094 / 2, 853168, 866642 / 2, 840286, 860760 / 2, 862754,
1305 843678 / 2, 872552, 837586 / 2, 851270, 877414 / 2, 834188,
1306 863062 / 2, 857846, 841780 / 2, 879054, 1724476 / 2, 1769072,
1307 1747588 / 2, 1738844, 1758240 / 2, 1742916, 1761612 / 2, 1755808,
1308 1737684 / 2, 1750780, 1747356 / 2, 1754152, 1748348 / 2, 1753324,
1309 1743320 / 2, 1754316, 2506896 / 2, 2564262, 2629188 / 2, 2515824,
1310 2598390 / 2, 2569236, 2537352 / 2, 2645118, 2508444 / 2, 2571480,
1311 2610576 / 2, 2510442, 2618208 / 2, 2566584, 2544570 / 2, 2614536,
1312 3458904 / 2, 3502688, 3474792 / 2, 3505976, 3499360 / 2, 3488264,
1313 3485848 / 2, 3512832, 3500616 / 2, 3482520, 3489624 / 2, 3469008,
1314 3495992 / 2, 3524376, 3465680 / 2, 3526264}));
1315
1316 ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 128, 4),
1317 testing::ElementsAreArray(
1318 {87920 / 2, 80024, 92288 / 2, 103712, 228148 / 2, 224820,
1319 233812 / 2, 213124, 271284 / 2, 271788, 332772 / 2, 328236,
1320 419328 / 2, 431328, 411968 / 2, 417248}));
1321
1322 ASSERT_THAT(TestPerChannelDotprodMatrixBatchVectorMultiply(4, 128, 8),
1323 testing::ElementsAreArray(
1324 {87920 / 2, 80024, 92288 / 2, 103712, 228148 / 2, 224820,
1325 233812 / 2, 213124, 271284 / 2, 271788, 332772 / 2, 328236,
1326 419328 / 2, 431328, 411968 / 2, 417248, 482680 / 2, 523840,
1327 560800 / 2, 593560, 563940 / 2, 609924, 566868 / 2, 644772,
1328 743708 / 2, 857780, 818972 / 2, 823284, 708384 / 2, 695008,
1329 730912 / 2, 872096}));
1330 }
1331
1332 TEST(uKernels, DotprodSparseMatrixBatchVectorMultiplyAccumulate) {
1333 EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 16, 1),
1334 testing::ElementsAre(0));
1335 EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 32, 1),
1336 testing::ElementsAre(1240));
1337 EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 1),
1338 testing::ElementsAre(26544));
1339 EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 2),
1340 testing::ElementsAre(26544, 24344));
1341 EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(4, 64, 4),
1342 testing::ElementsAreArray(
1343 {26544, 15866, 22140, 11408, 24344, 53248, 42704, 39900,
1344 48000, 94146, 101892, 81876, 87712, 105160, 148304, 75936}));
1345
1346 const bool kNegative = true;
1347 EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 1, kNegative),
1348 testing::ElementsAre(8764));
1349 EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(2, 64, 2, kNegative),
1350 testing::ElementsAre(8764, 5196, 7204, 11148));
1351 }
1352
1353 #ifdef __ANDROID__
1354 TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
1355 // Note we use 29 columns as this exercises all the NEON kernel paths: the
1356 // 16-block SIMD code, the 8-block postamble, and the leftover postamble.
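  // (29 = 16 + 8 + 5: one full 16-wide block, one 8-wide postamble block, and
  // a 5-element leftover.)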
1357 const int a_rows = 4, a_cols = 29;
1358 const int kWeightsPerUint32 = 4;
1359 /* clang-format off */
1360 const float a_float_data[] = {
1361 /* 1st row */
1362 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
1363 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
1364 24.24, 25.25, 26.26, 27.27, 28.28, 0,
1365 /* 2nd row */
1366 -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
1367 -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
1368 -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
1369 /* 3rd row */
1370 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
1371 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
1372 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
1373 /* 4th row */
1374 -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
1375 -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
1376 -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
1377
1378 int8_t* a_int8_data = reinterpret_cast<int8_t*>(
1379 aligned_malloc(a_rows * a_cols, kWeightsPerUint32));
1380 float a_min, a_max;
1381 float scaling_factor_a;
1382 SymmetricQuantizeFloats(a_float_data, a_rows * a_cols, a_int8_data, &a_min,
1383 &a_max, &scaling_factor_a);
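  // Symmetric quantization scales by the largest absolute value in the matrix
  // (28.28), so each element maps to round(value * 127 / 28.28), e.g.
  // 1.1 -> 5 and 28.28 -> 127.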
1384 const int8_t expected_a_int8_data[] = {
1385 /* 1st row */
1386 5, 10, 15, 20, 25, 30, 35, 40, 44, 45, 50, 54, 59, 64, 68, 73, 77, 82, 86,
1387 91, 95, 100, 104, 109, 113, 118, 122, 127, 0,
1388 /* 2nd row */
1389 -5, -10, -15, -20, -25, -30, -35, -40, -44, -45, -50, -54, -59, -64, -68,
1390 -73, -77, -82, -86, -91, -95, -100, -104, -109, -113, -118, -122, -127, 0,
1391 /* 3rd row */
1392 5, -10, 15, -20, 25, -30, 35, -40, 44, -45, 50, -54, 59, -64, 68, -73, 77,
1393 -82, 86, -91, 95, -100, 104, -109, 113, -118, 122, -127, 0,
1394 /* 4th row */
1395 -5, 10, -15, 20, -25, 30, -35, 40, -44, 45, -50, 54, -59, 64, -68, 73, -77,
1396 82, -86, 91, -95, 100, -104, 109, -113, 118, -122, 127, 0,
1397 };
1398 for (int i = 0; i < a_rows * a_cols; ++i) {
1399 EXPECT_EQ(expected_a_int8_data[i], a_int8_data[i]);
1400 }
1401
1402 const int b_rows = 29, b_cols = 1, batches = 2;
1403 const float b_float_data[] = {
1404 /* batch 1 */
1405 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
1406 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
1407 1.0,
1408 /* batch 2 */
1409 2.5, -2.1, 3.0, -1.3, 1.3, -1.1, 2.0, -1.7, 1.9, -1.5, 0.5, -0.7, 0.8, -0.3,
1410 2.8, -2.8, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1, 1.0, -2.5, 0.7, -1.9,
1411 0.2,
1412 };
1413
1414 // Quantized values of B:
1415 int8_t b_int8_data[b_rows * b_cols * batches];
1416 float b_min, b_max;
1417 float scaling_factor_b[batches];
1418 SymmetricQuantizeFloats(b_float_data, b_rows * b_cols, b_int8_data, &b_min,
1419 &b_max, &scaling_factor_b[0]);
1420 SymmetricQuantizeFloats(&b_float_data[b_rows * b_cols], b_rows * b_cols,
1421 &b_int8_data[b_rows * b_cols], &b_min, &b_max,
1422 &scaling_factor_b[1]);
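  // Each batch is quantized against its own largest absolute value (1.0 for
  // batch 1, 3.0 for batch 2), e.g. 2.5 -> round(2.5 * 127 / 3.0) = 106.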
1423
1424 const int8_t expected_b_int8_data[] = {
1425 /* batch 1 */
1426 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127,
1427 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127,
1428 127,
1429 /* batch 2 */
1430 106, -89, 127, -55, 55, -47, 85, -72, 80, -64, 21, -30, 34, -13, 119, -119,
1431 47, -97, 80, -80, 89, -21, 102, -4, 42, -106, 30, -80, 8,
1432 };
1433 /* clang-format on */
1434 for (int i = 0; i < b_rows * b_cols * batches; ++i) {
1435 EXPECT_EQ(expected_b_int8_data[i], b_int8_data[i]);
1436 }
1437
1438 // Full float operation results in:
1439 // -13.69, 13.69, 414.11, -414.11
1440 // -6.325, 6.325, 631.263, -631.263
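  // The quantized path below recovers slightly different values (see
  // expected_c_float_data) due to int8 quantization error.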
1441 float c_float_data[a_rows * b_cols * batches];
1442 for (int i = 0; i < a_rows * b_cols * batches; ++i) {
1443 c_float_data[i] = 0.0;
1444 }
1445
1446 // Test the quantized product: each accumulated int8 dot product is scaled
  // back to float with scale_a * scale_b.
1447 const float scaling_factor_c[2] = {
1448 scaling_factor_a * scaling_factor_b[0],
1449 scaling_factor_a * scaling_factor_b[1],
1450 };
1451 MatrixBatchVectorMultiplyAccumulate(a_int8_data, a_rows, a_cols, b_int8_data,
1452 scaling_factor_c, batches, c_float_data);
1453
1454 // Assert we obtain the expected recovered float values.
1455 const float expected_c_float_data[] = {
1456 -14.474, 14.474, 414.402, -414.402, -6.92228, 6.92228, 632.042, -632.042,
1457 };
1458 for (int i = 0; i < a_rows * b_cols * batches; ++i) {
1459 EXPECT_NEAR(expected_c_float_data[i], c_float_data[i], 0.001);
1460 }
1461
1462 // Call version of MatrixBatchVectorMultiplyAccumulate that uses
1463 // CpuBackendGemm.
1464 std::vector<int32_t> accum_scratch(a_rows * batches);
1465 std::vector<float> c_float_data_2(a_rows * batches, 0.0);
1466 CpuBackendContext context;
1467 MatrixBatchVectorMultiplyAccumulate(
1468 a_int8_data, a_rows, a_cols, b_int8_data, scaling_factor_c, batches,
1469 accum_scratch.data(), c_float_data_2.data(), &context);
1470
1471 // Assert (again) we obtain the expected recovered float values.
1472 for (int i = 0; i < a_rows * b_cols * batches; ++i) {
1473 EXPECT_NEAR(expected_c_float_data[i], c_float_data_2[i], 0.001);
1474 }
1475
1476 aligned_free(a_int8_data);
1477 }
1478 #endif // __ANDROID__
1479
1480 TEST(uKernels, SparseMatrixBatchVectorMultiplyAccumulateTest) {
1481 const int kRow = 4;
1482 const int kCol = 48;
1483 const int kBatch = 2;
1484 /* clang-format off */
1485 float matrix[kRow * kCol] = {
1486 /* 1st row */
1487 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
1488 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
1489 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
1490 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0, 0, 0, 0,
1491 /* 2nd row */
1492 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
1493 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
1494 -25.25, -26.26, -27.27, -28.28, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
1495 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0,
1496 /* 3rd row */
1497 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
1498 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
1499 -26.26, 27.27, -28.28, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
1500 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0,
1501 /* 4th row */
1502 -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
1503 -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
1504 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
1505 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0, 0, 0, 0};
1506
1507 // BCSR format of the above matrix.
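  // Values are stored in 16-column blocks; for every row the ledger holds the
  // number of non-zero blocks followed by their block-column indices, e.g.
  // "2, 0, 2" means two blocks, at block columns 0 and 2.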
1508 float matrix_values[] = {
1509 /* 1st row */
1510 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
1511 14.14, 15.15, 16.16, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38, 39.39,
1512 40.40, 41.41, 42.42, 43.43, 44.44, 0, 0, 0, 0,
1513 /* 2nd row */
1514 -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24, -25.25,
1515 -26.26, -27.27, -28.28, 0, 0.0, 0.0, 0.0,
1516 /* 3rd row */
1517 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25, -26.26,
1518 27.27, -28.28, 0, 0.0, 0.0, 0.0,
1519 /* 4th row */
1520 -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
1521 -13.13, 14.14, -15.15, 16.16, -33.33, 34.34, -35.35, 36.36, -37.37, 38.38,
1522 -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0, 0, 0, 0};
1523 uint8_t ledger[] = {
1524 2, 0, 2, // 1st row
1525 1, 1, // 2nd row
1526 1, 1, // 3rd row
1527 2, 0, 2 // 4th row
1528 };
1529
1530 float vector[kBatch * kCol] = {
1531 /* 1st batch */
1532 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
1533 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
1534 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
1535 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
1536 /* 2nd batch */
1537 2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
1538 -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0,
1539 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1, 1.0, -2.5,
1540 0.7, -1.9, 0.2, 0.0, 0.1, 0.2,
1541 };
1542 /* clang-format on */
1543
1544 std::vector<float> dense_output(kRow * kBatch, 0.0);
1545 MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch,
1546 dense_output.data());
1547
1548 EXPECT_THAT(dense_output, ElementsAreArray(ArrayFloatNear(
1549 {-13.69, 6.06001, 272.7, -608.03, -9.66602,
1550 -10.201, 10.201, -713.897949},
1551 1e-4)));
1552
1553 std::vector<float> sparse_output(kRow * kBatch, 0.0);
1554 SparseMatrixBatchVectorMultiplyAccumulate(
1555 matrix_values, ledger, kRow, kCol, vector, kBatch, sparse_output.data());
1556
1557 EXPECT_THAT(sparse_output,
1558 ElementsAreArray(ArrayFloatNear(dense_output, 1e-4)));
1559 }
1560
1561 #ifdef __ANDROID__
1562 TEST(uKernels,
1563 SparseMatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
1564 const int kRow = 4;
1565 const int kCol = 48;
1566 const int kBatch = 2;
1567 /* clang-format off */
1568 const int8_t quantized_matrix[] = {
1569 /* 1st row */
1570 3, 6, 9, 13, 16, 19, 22, 25, 28, 29, 32, 35, 38, 40, 43, 46, 0, 0, 0, 0,
1571 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 95, 98, 101, 104, 107, 110, 113, 115,
1572 118, 121, 124, 127, 0, 0, 0, 0,
1573 /* 2nd row */
1574 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -49, -52, -55, -58, -61,
1575 -64, -66, -69, -72, -75, -78, -81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1576 0, 0, 0, 0, 0, 0, 0,
1577 /* 3rd row */
1578 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, -52, 55, -58, 61, -64,
1579 66, -69, 72, -75, 78, -81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1580 0, 0, 0, 0,
1581 /* 4th row */
1582 -3, 6, -9, 13, -16, 19, -22, 25, -28, 29, -32, 35, -38, 40, -43, 46, 0, 0,
1583 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -95, 98, -101, 104, -107, 110,
1584 -113, 115, -118, 121, -124, 127, 0, 0, 0, 0,
1585 };
1586 const int8_t quantized_matrix_values[] = {
1587 /* 1st row */
1588 3, 6, 9, 13, 16, 19, 22, 25, 28, 29, 32, 35, 38, 40, 43, 46, 95, 98, 101,
1589 104, 107, 110, 113, 115, 118, 121, 124, 127, 0, 0, 0, 0,
1590 /* 2nd row */
1591 -49, -52, -55, -58, -61, -64, -66, -69, -72, -75, -78, -81, 0, 0, 0, 0,
1592 /* 3rd row */
1593 49, -52, 55, -58, 61, -64, 66, -69, 72, -75, 78, -81, 0, 0, 0, 0,
1594 /* 4th row */
1595 -3, 6, -9, 13, -16, 19, -22, 25, -28, 29, -32, 35, -38, 40, -43, 46, -95,
1596 98, -101, 104, -107, 110, -113, 115, -118, 121, -124, 127, 0, 0, 0, 0,
1597 };
1598 uint8_t ledger[] = {
1599 2, 0, 2, // 1st row
1600 1, 1, // 2nd row
1601 1, 1, // 3rd row
1602 2, 0, 2 // 4th row
1603 };
1604
1605 float matrix_scaling_factor = 0.349921;
1606
1607 const int8_t quantized_vector[] = {
1608 /* 1st batch */
1609 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127,
1610 -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127,
1611 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127,
1612 -127, 127, -127, 127, -127, 127, -127, 127, -127,
1613 /* 2nd batch */
1614 106, 0, -89, 0, 127, 0, -55, 0, 55, 0, -47, 0, 85, 0, -72, 0, 80, 0,
1615 -64, 0, 21, 0, -30, 0, 34, 0, -13, 0, 119, 0, -119, 0, 47, -97, 80, -80,
1616 89, -21, 102, -4, 42, -106, 30, -80, 8, 1, 2, 3,
1617 };
1618 float vector_scaling_factor[2] = {0.00787402, 0.023622};
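  // These factors match symmetric quantization of the data above:
  // 0.349921 ~= 44.44 / 127, 0.00787402 ~= 1.0 / 127, 0.023622 ~= 3.0 / 127.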
1619
1620 /* clang-format on */
1621 float result_scaling_factor[2] = {
1622 matrix_scaling_factor * vector_scaling_factor[0],
1623 matrix_scaling_factor * vector_scaling_factor[1],
1624 };
1625 std::vector<float> dense_output(kRow * kBatch, 0.0);
1626 MatrixBatchVectorMultiplyAccumulate(quantized_matrix, kRow, kCol,
1627 quantized_vector, result_scaling_factor,
1628 kBatch, dense_output.data());
1629
1630 EXPECT_THAT(dense_output,
1631 ElementsAreArray(ArrayFloatNear(
1632 {-13.646927, 6.298582, 272.938538, -607.813110, -6.637464,
1633 -9.381721, 9.381721, -713.845642})));
1634
1635 std::vector<float> sparse_output(kRow * kBatch, 0.0);
1636 SparseMatrixBatchVectorMultiplyAccumulate(
1637 quantized_matrix_values, ledger, kRow, kCol, quantized_vector,
1638 result_scaling_factor, kBatch, sparse_output.data());
1639
1640 EXPECT_THAT(sparse_output,
1641 ElementsAreArray(ArrayFloatNear(
1642 {-13.646927, 6.298582, 272.938538, -607.813110, -6.637464,
1643 -9.381721, 9.381721, -713.845642})));
1644 }
1645 #endif // __ANDROID__
1646
1647 TEST(uKernels, VectorVectorCwiseProductTest) {
1648 constexpr int kVectorSize = 10;
1649 static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
1650 -2.5, 3.0, -3.5, 4.0, -4.5};
1651 static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1,
1652 -0.1, 0.1, -0.1, 0.1, -0.1};
1653 std::vector<float> output(kVectorSize);
1654 VectorVectorCwiseProduct(input1, input2, kVectorSize, output.data());
1655 EXPECT_THAT(output,
1656 ElementsAreArray(ArrayFloatNear(
1657 {0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45})));
1658 }
1659
1660 TEST(uKernels, VectorVectorCwiseProductAccumulateTest) {
1661 constexpr int kVectorSize = 10;
1662 static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
1663 -2.5, 3.0, -3.5, 4.0, -4.5};
1664 static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1,
1665 -0.1, 0.1, -0.1, 0.1, -0.1};
1666 std::vector<float> output(kVectorSize);
1667 std::fill(output.begin(), output.end(), 1.0);
1668 VectorVectorCwiseProductAccumulate(input1, input2, kVectorSize,
1669 output.data());
1670 EXPECT_THAT(output,
1671 ElementsAreArray(ArrayFloatNear(
1672 {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
1673 }
1674
1675 TEST(uKernels, VectorBatchVectorAddTest) {
1676 constexpr int kVectorSize = 3;
1677 constexpr int kBatchSize = 2;
1678 static float input[kVectorSize] = {0.0, -0.5, 1.0};
1679 std::vector<float> output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
1680 VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data());
1681 EXPECT_THAT(output,
1682 testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0}));
1683 }
1684
1685 TEST(uKernels, VectorBatchVectorAssignTest) {
1686 constexpr int kVectorSize = 5;
1687 constexpr int kBatchSize = 3;
1688 static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
1689 std::vector<float> output(kVectorSize * kBatchSize);
1690 VectorBatchVectorAssign(input, kVectorSize, kBatchSize, output.data());
1691 EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
1692 {0.0, -0.5, 1.0, -1.5, 2.0, 0.0, -0.5, 1.0, -1.5, 2.0,
1693 0.0, -0.5, 1.0, -1.5, 2.0})));
1694 }
1695
1696 TEST(uKernels, ApplySigmoidToVectorTest) {
1697 constexpr int kVectorSize = 5;
1698 static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
1699 std::vector<float> output(kVectorSize);
1700 ApplySigmoidToVector(input, kVectorSize, output.data());
1701 EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
1702 {0.5, 0.377541, 0.731059, 0.182426, 0.880797})));
1703 }
1704
1705 TEST(uKernels, ApplyActivationToVectorTest) {
1706 constexpr int kVectorSize = 5;
1707 static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
1708 std::vector<float> output(kVectorSize);
1709 ApplyActivationToVector(input, kVectorSize, kTfLiteActRelu, output.data());
1710 EXPECT_THAT(output,
1711 ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 0.0, 2.0})));
1712
1713 ApplyActivationToVector(input, kVectorSize, kTfLiteActTanh, output.data());
1714 EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
1715 {0.0, -0.462117, 0.761594, -0.905148, 0.964028})));
1716 }
1717
1718 TEST(uKernels, Sub1VectorTest) {
1719 constexpr int kVectorSize = 5;
1720 static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
1721 std::vector<float> output(kVectorSize);
1722 Sub1Vector(input, kVectorSize, output.data());
1723 EXPECT_THAT(output,
1724 ElementsAreArray(ArrayFloatNear({1.0, 1.5, 0.0, 2.5, -1.0})));
1725 }
1726
1727 TEST(uKernels, Sub1VectorInt16Test) {
1728 constexpr int kVectorSize = 30;
1729 static int16_t input[kVectorSize] = {
1730 32760, 300, 1, 2, 3, 4, 5, 6, 300, 1000,
1731 32767, 32000, 300, 1, 2, 3, 4, 5, 56, 300,
1732 1000, 32767, 32761, 1300, 1, 2, 3, 4, 5, 6,
1733 };
1734 std::vector<int16_t> output(kVectorSize);
1735 Sub1Vector(input, kVectorSize, output.data());
1736 EXPECT_THAT(
1737 output,
1738 testing::ElementsAreArray({
1739 7, 32467, 32766, 32765, 32764, 32763, 32762, 32761, 32467, 31767,
1740 0, 767, 32467, 32766, 32765, 32764, 32763, 32762, 32711, 32467,
1741 31767, 0, 6, 31467, 32766, 32765, 32764, 32763, 32762, 32761,
1742 }));
1743 }
1744
1745 TEST(uKernels, VectorBatchVectorCwiseProductAccumulateInteger) {
1746 constexpr int kVectorSize = 29;
1747 constexpr int kBatchSize = 4;
1748 static int16_t vector[kVectorSize] = {-10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
1749 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
1750 10, 11, 12, 13, 14, 15, 16, 17, 18};
1751 const std::vector<int16_t> batch_vector = {
1752 /* batch 0 */
1753 10, 11, 12, 13, 14, 15, 16, 17, 18, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1,
1754 2, 3, 4, 5, 6, 7, 8, 9,
1755 /* batch 1 */
1756 -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 0, 1,
1757 2, 3, 4, 5, 6, 7, 8, 9,
1758 /* batch 2 */
1759 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12,
1760 13, 14, 15, 16, 17, 18,
1761 /* batch 3 */
1762 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12,
1763 13, 14, 15, 16, 17, 18};
1764 std::vector<int16_t> batch_output = {
1765 /* batch 0 */
1766 -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 0, 1,
1767 2, 3, 4, 5, 6, 7, 8, 9,
1768 /* batch 1 */
1769 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10, 9, 8, 7, 6, 5,
1770 4, 3, 2, 1, 10, 11, 12,
1771 /* batch 2 */
1772 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10, 11, 12,
1773 13, 14, 15, 16, 17, 18,
1774 /* batch 3 */
1775 10, 11, 12, 13, 14, 15, 16, 17, 18, -10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1,
1776 13, 14, 15, 16, 17, 18};
1777 // Test with 0.25 scale, which is decomposed into (1073741824, -1).
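  // (The pair encodes 0.25 as (1073741824 / 2^31) * 2^-1 = 0.5 * 0.5.)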
1778 VectorBatchVectorCwiseProductAccumulate(vector, kVectorSize,
1779 batch_vector.data(), kBatchSize,
1780 1073741824, -1, batch_output.data());
1781
1782 const std::vector<int16_t> expected_output = {
1783 /* batch 0 */
1784 -35, 34, 32, 30, 27, 24, 20, 16, 11, -1, 10, 13, 16, 18, 19, 20, 21, 21,
1785 20, 0, 4, 8, 12, 17, 23, 29, 35, 42, 50,
1786 /* batch 1 */
1787 27, 24, 20, 18, 15, 14, 12, 12, 1, 2, 2, 6, 10, 15, 20, 26, 32, 39, 26, 9,
1788 11, 13, 15, 18, 22, 26, 30, 35, 51,
1789 /* batch 2 */
1790 11, 15, 4, 7, 8, 10, 10, 11, 10, 10, 8, 12, -6, 15, 14, 14, 12, 11, 8, 6,
1791 27, 32, 46, 54, 61, 70, 78, 88, 97,
1792 /* batch 3 */
1793 17, 21, 14, 17, 18, 20, 20, 21, 20, 20, 18, -7, 13, 14, 13, 13, 11, 10, 7,
1794 5, 26, 31, 37, 56, 63, 72, 80, 90, 99};
1795 // Only allow 1 element difference for the rounding result.
1796 CompareRoundingResults<int16_t>(4 * 29, expected_output.data(),
1797 batch_output.data(), 1, 1);
1798 }
1799
1800 TEST(uKernels, VectorBatchVectorCwiseProductAccumulateFloat) {
1801 constexpr int kVectorSize = 29;
1802 constexpr int kBatchSize = 4;
1803 static float input[kVectorSize] = {
1804 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f,
1805 9.9f, 10.10f, 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f,
1806 17.17f, 18.18f, 19.19f, 20.20f, 21.21f, 22.22f, 23.23f, 24.24f,
1807 25.25f, 26.26f, 27.27f, 28.28f, 0.0f};
1808 std::vector<float> output = {
1809 /* batch 0 */
1810 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.10f, 11.11f,
1811 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.20f,
1812 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f,
1813 /* batch 1 */
1814 -1.1f, -2.2f, -3.3f, -4.4f, -5.5f, -6.6f, -7.7f, -8.8f, -9.9f, -10.10f,
1815 -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f,
1816 -19.19f, -20.20f, -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f,
1817 -27.27f, -28.28f, 0.0f,
1818 /* batch 2 */
1819 1.1f, -2.2f, 3.3f, -4.4f, 5.5f, -6.6f, 7.7f, -8.8f, 9.9f, -10.10f, 11.11f,
1820 -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f,
1821 -20.20f, 21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f,
1822 -28.28f, 0.0f,
1823 /* batch 3 */
1824 -1.1f, 2.2f, -3.3f, 4.4f, -5.5f, 6.6f, -7.7f, 8.8f, -9.9f, 10.10f,
1825 -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f,
1826 -19.19f, 20.20f, -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f,
1827 -27.27f, 28.28f, 0.0f};
1828 VectorBatchVectorCwiseProductAccumulate(input, kVectorSize, output.data(),
1829 kBatchSize, output.data());
1830
1831 // Expect output = input * output + output.
1832 const std::vector<float> expected_output = {
1833 /* batch 0 */
1834 2.31f, 7.04f, 14.19f, 23.76f, 35.75f, 50.16f, 66.99f, 86.24f, 107.91f,
1835 112.11f, 134.5421f, 159.0144f, 185.5269f, 214.0796f, 244.6725f, 277.3056f,
1836 311.9789f, 348.6924f, 387.4461f, 428.24f, 471.0741f, 515.9484f, 562.8629f,
1837 611.8176f, 662.8125f, 715.8476f, 770.9229f, 828.0384f, 0.0f,
1838 /* batch 1 */
1839 -2.31f, -7.04f, -14.19f, -23.76f, -35.75f, -50.16f, -66.99f, -86.24f,
1840 -107.91f, -112.11f, -134.5421f, -159.0144f, -185.5269f, -214.0796f,
1841 -244.6725f, -277.3056f, -311.9789f, -348.6924f, -387.4461f, -428.24f,
1842 -471.0741f, -515.9484f, -562.8629f, -611.8176f, -662.8125f, -715.8476f,
1843 -770.9229f, -828.0384f, 0.0f,
1844 /* batch 2 */
1845 2.31f, -7.04f, 14.19f, -23.76f, 35.75f, -50.16f, 66.99f, -86.24f, 107.91f,
1846 -112.11f, 134.5421f, -159.0144f, 185.5269f, -214.0796f, 244.6725f,
1847 -277.3056f, 311.9789f, -348.6924f, 387.4461f, -428.24f, 471.0741f,
1848 -515.9484f, 562.8629f, -611.8176f, 662.8125f, -715.8476f, 770.9229f,
1849 -828.0384f, 0.0f,
1850 /* batch 3 */
1851 -2.31f, 7.04f, -14.19f, 23.76f, -35.75f, 50.16f, -66.99f, 86.24f,
1852 -107.91f, 112.11f, -134.5421f, 159.0144f, -185.5269f, 214.0796f,
1853 -244.6725f, 277.3056f, -311.9789f, 348.6924f, -387.4461f, 428.24f,
1854 -471.0741f, 515.9484f, -562.8629f, 611.8176f, -662.8125f, 715.8476f,
1855 -770.9229f, 828.0384f, 0.0f};
1856 EXPECT_THAT(output, testing::ElementsAreArray(
1857 ArrayFloatNear(expected_output, 6.5e-5f)));
1858 }
1859
1860 TEST(uKernels, VectorBatchVectorCwiseProductNoAccumulate) {
1861 constexpr int kVectorSize = 29;
1862 constexpr int kBatchSize = 4;
1863 static float input[kVectorSize] = {
1864 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
1865 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
1866 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
1867 std::vector<float> output = {
1868 /* batch 0 */
1869 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
1870 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
1871 24.24, 25.25, 26.26, 27.27, 28.28, 0,
1872 /* batch 1 */
1873 -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
1874 -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
1875 -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
1876 /* batch 2 */
1877 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
1878 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
1879 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
1880 /* batch 3 */
1881 -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
1882 -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
1883 -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
1884 VectorBatchVectorCwiseProduct(input, kVectorSize, output.data(), kBatchSize,
1885 output.data());
1886
1887 // Expect output = input * output.
1888 const std::vector<float> expected_output = {
1889 /* batch 0 */
1890 1.210000, 4.840000, 10.889999, 19.360001, 30.250000, 43.559998, 59.289997,
1891 77.440002, 98.009995, 102.010010, 123.432091, 146.894394, 172.396896,
1892 199.939606, 229.522491, 261.145599, 294.808899, 330.512421, 368.256134,
1893 408.040039, 449.864075, 493.728363, 539.632874, 587.577576, 637.562500,
1894 689.587585, 743.652954, 799.758423, 0.000000,
1895 /* batch 1 */
1896 -1.210000, -4.840000, -10.889999, -19.360001, -30.250000, -43.559998,
1897 -59.289997, -77.440002, -98.009995, -102.010010, -123.432091, -146.894394,
1898 -172.396896, -199.939606, -229.522491, -261.145599, -294.808899,
1899 -330.512421, -368.256134, -408.040039, -449.864075, -493.728363,
1900 -539.632874, -587.577576, -637.562500, -689.587585, -743.652954,
1901 -799.758423, 0.000000,
1902 /* batch 2 */
1903 1.210000, -4.840000, 10.889999, -19.360001, 30.250000, -43.559998,
1904 59.289997, -77.440002, 98.009995, -102.010010, 123.432091, -146.894394,
1905 172.396896, -199.939606, 229.522491, -261.145599, 294.808899, -330.512421,
1906 368.256134, -408.040039, 449.864075, -493.728363, 539.632874, -587.577576,
1907 637.562500, -689.587585, 743.652954, -799.758423, 0.000000,
1908 /* batch 3 */
1909 -1.210000, 4.840000, -10.889999, 19.360001, -30.250000, 43.559998,
1910 -59.289997, 77.440002, -98.009995, 102.010010, -123.432091, 146.894394,
1911 -172.396896, 199.939606, -229.522491, 261.145599, -294.808899, 330.512421,
1912 -368.256134, 408.040039, -449.864075, 493.728363, -539.632874, 587.577576,
1913 -637.562500, 689.587585, -743.652954, 799.758423, 0.000000};
1914 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
1915 }
1916
1917 TEST(uKernels, BatchVectorBatchVectorDotProductTest) {
1918 constexpr int kVectorSize = 5;
1919 constexpr int kBatch = 2;
1920 static float input1[kVectorSize * kBatch] = {0.0, -0.5, 1.0, -1.5, 2.0,
1921 -2.5, 3.0, -3.5, 4.0, -4.5};
1922 static float input2[kVectorSize * kBatch] = {0.1, -0.1, 0.1, -0.1, 0.1,
1923 -0.1, 0.1, -0.1, 0.1, -0.1};
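  // Each output element is a per-batch dot product, e.g. batch 0:
  // 0.0*0.1 + (-0.5)*(-0.1) + 1.0*0.1 + (-1.5)*(-0.1) + 2.0*0.1 = 0.5.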
1924 std::vector<float> output(kBatch);
1925 BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch,
1926 output.data());
1927 EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({0.5, 1.75})));
1928 }
1929
1930 TEST(uKernels, BatchVectorBatchVectorDotProductIntegerTest) {
1931 constexpr int kVectorSize = 5;
1932 constexpr int kBatch = 2;
1933 static int16_t input1[kVectorSize * kBatch] = {0, 5, 10, -15, 20,
1934 -25, 30, -35, 40, -45};
1935 static int16_t input2[kVectorSize * kBatch] = {1, -1, 1, -1, 1,
1936 -1, 1, -1, 1, 1};
1937 std::vector<int32_t> output(kBatch);
1938 BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch,
1939 output.data());
1940 EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({40, 85})));
1941 }
1942
1943 TEST(uKernels, ReductionSumVectorTest) {
1944 constexpr int kInputVectorSize = 10;
1945 constexpr int kOutputVectorSize1 = 5;
1946 constexpr int kReductionSize1 = 2;
1947 static float input[kInputVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
1948 0.0, -0.5, 1.0, 1.0, 2.0};
1949 std::vector<float> result1(kOutputVectorSize1);
1950 ReductionSumVector(input, result1.data(), kOutputVectorSize1,
1951 kReductionSize1);
1952 EXPECT_THAT(result1,
1953 ElementsAreArray(ArrayFloatNear({-0.5, -0.5, 2.0, 0.5, 3.0})));
1954
1955 constexpr int kOutputVectorSize2 = 2;
1956 constexpr int kReductionSize2 = 5;
1957 std::vector<float> result2(kOutputVectorSize2);
1958 ReductionSumVector(input, result2.data(), kOutputVectorSize2,
1959 kReductionSize2);
1960 EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
1961 }
1962
1963 TEST(uKernels, ReductionSumVectorIntegerTest) {
1964 constexpr int kInputVectorSize = 10;
1965 constexpr int kOutputVectorSize1 = 5;
1966 constexpr int kReductionSize1 = 2;
1967 static int32_t input[kInputVectorSize] = {1, 2, 1, 5, -3, 2, 1, 2, 5, 10};
1968 std::vector<int32_t> result1(kOutputVectorSize1);
1969 ReductionSumVector(input, result1.data(), kOutputVectorSize1,
1970 kReductionSize1);
1971 EXPECT_THAT(result1, testing::ElementsAreArray({3, 6, -1, 3, 15}));
1972 }
1973
1974 void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
1975 const int8_t* recurrent, int8_t recurrent_zp,
1976 int32_t input_effective_scale_a,
1977 int32_t input_effective_scale_b,
1978 int32_t recurrent_effective_scale_a,
1979 int32_t recurrent_effective_scale_b, int32_t n_batch,
1980 int32_t n_cell, int16_t* output);
1981
1982 TEST(uKernels, TwoGateSaturateAddTest) {
1983 const std::vector<int8_t> input1 = {1, 2, 3, 4, 55, 66, 77};
1984 const std::vector<int8_t> input2 = {100, 2, 3, 4, 55, 66, 77};
1985 const int32_t input1_zp = 10;
1986 const int32_t input2_zp = -5;
1987 const int32_t multiplier1 = 1347771520;
1988 const int32_t shift1 = -7;
1989 const int32_t multiplier2 = 1047577121;
1990 const int32_t shift2 = -6;
1991 std::vector<int16_t> output(7);
1992
1993 TwoGateSaturatingAdd(input1.data(), input1_zp, input2.data(), input2_zp,
1994 multiplier1, shift1, multiplier2, shift2, 1, 7,
1995 output.data());
1996
1997 const std::vector<int16_t> expected_output = {1, 0, 0, 0, 0, 1, 1};
1998 EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
1999 }
2000
2001 namespace {
2002 // Parameterized test: mean, difference, tolerance.
2003 // Input is constructed as [mean-2*diff, mean-diff, mean+diff, mean+2*diff]
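// For input [mean-2*diff, mean-diff, mean+diff, mean+2*diff] the variance is
// 2.5*diff^2, so the normalized output is
// [-sqrt(1.6), -sqrt(0.4), sqrt(0.4), sqrt(1.6)] whenever diff != 0.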
2004 class MeanStddevNormalizationTest
2005 : public testing::TestWithParam<std::tuple<float, float, float>> {};
2006 } // namespace
2007
2008 TEST_P(MeanStddevNormalizationTest, SeparateBatches) {
2009 const float mean = std::get<0>(GetParam());
2010 const float diff = std::get<1>(GetParam());
2011 const float tolerance = std::get<2>(GetParam());
2012
2013 constexpr int kVectorSize = 4;
2014 const float input[kVectorSize] = {mean - 2 * diff, mean - diff, mean + diff,
2015 mean + 2 * diff};
2016 float output[kVectorSize];
2017 MeanStddevNormalization(input, output, kVectorSize, 1);
2018 std::vector<float> expected_output;
2019 if (diff == 0.0f) {
2020 expected_output.assign({0.0f, 0.0f, 0.0f, 0.0f});
2021 } else {
2022 const float ksqrt16 = std::sqrt(1.6f);
2023 const float ksqrt04 = std::sqrt(0.4f);
2024 expected_output.assign({-ksqrt16, -ksqrt04, ksqrt04, ksqrt16});
2025 }
2026 EXPECT_THAT(output, testing::ElementsAreArray(
2027 ArrayFloatNear(expected_output, tolerance)));
2028 }
2029
2030 INSTANTIATE_TEST_SUITE_P(
2031 uKernels, MeanStddevNormalizationTest,
2032 testing::Values(
2033 std::make_tuple(0.0f, 0.0f, 0.0f), // zero mean, zero variance
2034 std::make_tuple(0.0f, 0.01f, 2.53e-5f), // zero mean, small variance
2035 std::make_tuple(0.0f, 100.0f, 1.20e-7f), // zero mean, large variance
2036 std::make_tuple(0.01f, 0.0f, 0.0f), // small mean, zero variance
2037 std::make_tuple(0.01f, 0.01f, 2.53e-5f), // small mean, small variance
2038 std::make_tuple(0.01f, 100.0f, 1.20e-7f), // small mean, large variance
2039 std::make_tuple(100.0f, 0.0f, 0.0f), // large mean, zero variance
2040 std::make_tuple(100.0f, 0.01f, 1.81e-4f), // large mean, small variance
2041 std::make_tuple(100.0f, 100.0f, 1.20e-7f) // large mean, large variance
2042 ));
2043
2044 TEST(uKernels, MeanStddevNormalizationAllBatches) {
2045 constexpr int kVectorSize = 4;
2046 constexpr int kBatchSize = 9;
2047
2048 // Non-zero input.
2049 static float input[kVectorSize * kBatchSize] = {
2050 0.0f, 0.0f, 0.0f, 0.0f, // zero mean, zero variance
2051 -0.02f, -0.01f, 0.01f, 0.02f, // zero mean, small variance
2052 -200.0f, -100.0f, 100.0f, 200.0f, // zero mean, large variance
2053 0.01f, 0.01f, 0.01f, 0.01f, // small mean, zero variance
2054 -0.01f, 0.0f, 0.02f, 0.03f, // small mean, small variance
2055 -199.99f, -99.99f, 100.01f, 200.01f, // small mean, large variance
2056 100.0f, 100.0f, 100.0f, 100.0f, // large mean, zero variance
2057 99.98f, 99.99f, 100.01f, 100.02f, // large mean, small variance
2058 -100.0f, 0.0f, 200.0f, 300.0f, // large mean, large variance
2059 };
2060 float output[kVectorSize * kBatchSize];
2061 MeanStddevNormalization(input, output, kVectorSize, kBatchSize);
2062 const float ksqrt16 = std::sqrt(1.6f);
2063 const float ksqrt04 = std::sqrt(0.4f);
2064 const std::vector<float> expected_output = {
2065 0.0f, 0.0f, 0.0f, 0.0f, // zero mean, zero variance
2066 -ksqrt16, -ksqrt04, ksqrt04, ksqrt16, // zero mean, small variance
2067 -ksqrt16, -ksqrt04, ksqrt04, ksqrt16, // zero mean, large variance
2068 0.0f, 0.0f, 0.0f, 0.0f, // small mean, zero variance
2069 -ksqrt16, -ksqrt04, ksqrt04, ksqrt16, // small mean, small variance
2070 -ksqrt16, -ksqrt04, ksqrt04, ksqrt16, // small mean, large variance
2071 0.0f, 0.0f, 0.0f, 0.0f, // large mean, zero variance
2072 -ksqrt16, -ksqrt04, ksqrt04, ksqrt16, // large mean, small variance
2073 -ksqrt16, -ksqrt04, ksqrt04, ksqrt16, // large mean, large variance
2074 };
2075 EXPECT_THAT(output, testing::ElementsAreArray(
2076 ArrayFloatNear(expected_output, 1.81e-4f)));
2077 }
2078
2079 TEST(uKernels, MeanStddevNormalizationLargeVector) {
2080 const float mean = 100.0f;
2081 const float diff = 1.0f;
2082 // A large vector whose size is not a round multiple of any SIMD vector size.
2083 // Note that the size is odd.
2084 constexpr int kVectorSize = 16 * 16 + 16 + 1;
2085
2086 float input[kVectorSize];
2087 // First input is mean.
2088 input[0] = mean;
2089 // Rest is alternating between mean + diff and mean - diff.
2090 for (int i = 1; i < kVectorSize - 1; i += 2) {
2091 input[i + 0] = mean + diff;
2092 input[i + 1] = mean - diff;
2093 }
2094 float output[kVectorSize];
2095 MeanStddevNormalization(input, output, kVectorSize, 1);
2096
2097 float expected_output[kVectorSize];
2098 // First output should be 0.
2099 expected_output[0] = 0.0;
2100 // Rest should be alternating between ±√(N/(N-1)).
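  // With one element at the mean and (N-1)/2 pairs at mean +/- diff, the
  // variance is (N-1)*diff^2/N, so each off-mean element normalizes to
  // +/-sqrt(N/(N-1)).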
2101 const float expected_elem = std::sqrt(static_cast<double>(kVectorSize) /
2102 static_cast<double>(kVectorSize - 1));
2103 for (int i = 1; i < kVectorSize - 1; i += 2) {
2104 expected_output[i + 0] = +expected_elem;
2105 expected_output[i + 1] = -expected_elem;
2106 }
2107 EXPECT_THAT(output, testing::Pointwise(testing::FloatEq(), expected_output));
2108 }
2109
2110 } // namespace tensor_utils
2111 } // namespace tflite
2112
2113 #ifdef DOTPROD_BENCHMARKS
2114
2115 // Compile with --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" and
2116 // --copt="-DDOTPROD_BENCHMARKS"
2117 // Run with --benchmarks=all
2118 void BM_DotprodBatchOneMultiply(benchmark::State& state) {
2119 const int rows = state.range(0);
2120 const int cols = state.range(1);
2121 const int batch = state.range(2);
2122 const int copies = state.range(3);
2123
2124 // For some benchmarks we make multiple matrix copies. This allows us to
2125 // measure the performance differences of being entirely in cache vs.
2126 // out of cache.
2127 std::vector<tflite::tensor_utils::MatrixVectorData> datas;
2128 for (int i = 0; i < copies; i++) {
2129 datas.push_back(
2130 tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch));
2131 }
2132
2133 int copy = 0;
2134 for (auto _ : state) {
2135 copy = (copy + 1) % datas.size();
2136 auto& data = datas[copy];
2137 for (int i = 0; i < batch; i++) {
2138 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
2139 data.matrix.data(), data.rows, data.cols,
2140 data.vectors.data() + (data.cols * i), data.scale_factors.data(), 1,
2141 &data.results[0], 1);
2142 testing::DoNotOptimize(data.results[2]);
2143 }
2144 }
2145 }
2146 BENCHMARK(BM_DotprodBatchOneMultiply)
2147 ->Args({16, 16, 1, 1})
2148 ->Args({16, 16, 4, 1})
2149 ->Args({32, 32, 1, 1})
2150 ->Args({32, 32, 4, 1})
2151 ->Args({64, 64, 1, 1})
2152 ->Args({64, 64, 4, 1})
2153 ->Args({128, 128, 1, 1})
2154 ->Args({128, 128, 4, 1})
2155 ->Args({992, 992, 1, 1})
2156 ->Args({992, 992, 8, 1})
2157 ->Args({1024, 1024, 1, 1})
2158 ->Args({1024, 1024, 1, 8})
2159 ->Args({1024, 1024, 4, 1})
2160 ->Args({1024, 1024, 4, 8})
2161 ->Args({1024, 1024, 8, 1})
2162 ->Args({640, 2048, 1, 1})
2163 ->Args({640, 2048, 4, 1})
2164 ->Args({640, 2048, 8, 1})
2165 ->Args({640, 2048, 8, 8})
2166 ->Args({2048, 2048, 1, 1})
2167 ->Args({2048, 2048, 1, 8})
2168 ->Args({2048, 2048, 8, 1});
2169
2170 void BM_DotprodBatchFourMultiply(benchmark::State& state) {
2171 const int rows = state.range(0);
2172 const int cols = state.range(1);
2173 const int batch = state.range(2);
2174 const int copies = state.range(3);
2175
2176 // For some benchmarks we make multiple matrix copies. This allows us to
2177 // measure the performance differences of being entirely in cache vs.
2178 // out of cache.
2179 std::vector<tflite::tensor_utils::MatrixVectorData> datas;
2180 for (int i = 0; i < copies; i++) {
2181 datas.push_back(
2182 tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch));
2183 }
2184
2185 int copy = 0;
2186 for (auto _ : state) {
2187 copy = (copy + 1) % datas.size();
2188 auto& data = datas[copy];
2189 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
2190 data.matrix.data(), data.rows, data.cols, data.vectors.data(),
2191 data.scale_factors.data(), data.batch, &data.results[0], 1);
2192 testing::DoNotOptimize(data.results[2]);
2193 }
2194 }
2195 BENCHMARK(BM_DotprodBatchFourMultiply)
2196 ->Args({16, 16, 4, 1})
2197 ->Args({32, 32, 4, 1})
2198 ->Args({64, 64, 4, 1})
2199 ->Args({64, 256, 64, 1})
2200 ->Args({64, 256, 256, 1})
2201 ->Args({64, 256, 1024, 1})
2202 ->Args({64, 256, 12544, 1})
2203 ->Args({128, 128, 2, 1})
2204 ->Args({128, 128, 3, 1})
2205 ->Args({128, 128, 4, 1})
2206 ->Args({128, 128, 5, 1})
2207 ->Args({640, 640, 4, 1})
2208 ->Args({992, 992, 8, 1})
2209 ->Args({1024, 1024, 2, 1})
2210 ->Args({1024, 1024, 3, 1})
2211 ->Args({1024, 1024, 4, 1})
2212 ->Args({1024, 1024, 5, 1})
2213 ->Args({1024, 1024, 8, 1})
2214 ->Args({1024, 1024, 8, 8})
2215 ->Args({1024, 1024, 256, 1})
2216 ->Args({640, 2048, 2, 1})
2217 ->Args({640, 2048, 3, 1})
2218 ->Args({640, 2048, 4, 1})
2219 ->Args({640, 2048, 4, 8})
2220 ->Args({640, 2048, 8, 1})
2221 ->Args({2048, 2048, 3, 1})
2222 ->Args({2048, 2048, 4, 1})
2223 ->Args({2048, 2048, 4, 8})
2224 ->Args({2048, 2048, 5, 1})
2225 ->Args({2048, 2048, 8, 1});
2226
2227 void BM_DotprodSparseMultiply(benchmark::State& state) {
2228 const int rows = state.range(0);
2229 const int cols = state.range(1);
2230 const int batch = state.range(2);
2231
2232 const int copies = state.range(3);
2233
2234 // For some benchmarks we make multiple matrix copies. This allows us to
2235 // measure the performance differences of being entirely in cache vs.
2236 // out of cache.
2237 std::vector<tflite::tensor_utils::MatrixVectorData> datas;
2238 for (int i = 0; i < copies; i++) {
2239 datas.push_back(
2240 tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch));
2241 }
2242
2243 int copy = 0;
2244 for (auto _ : state) {
2245 copy = (copy + 1) % datas.size();
2246 auto& data = datas[copy];
2247 tflite::tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
2248 data.sparse_matrix.data(), data.ledger.data(), data.rows, data.cols,
2249 data.vectors.data(), data.scale_factors.data(), data.batch,
2250 &data.results[0]);
2251 testing::DoNotOptimize(data.results[2]);
2252 }
2253 }
2254 BENCHMARK(BM_DotprodSparseMultiply)
2255 ->Args({128, 128, 1, 1})
2256 ->Args({128, 128, 4, 1})
2257 ->Args({640, 640, 4, 1})
2258 ->Args({992, 992, 8, 1})
2259 ->Args({1024, 1024, 1, 1})
2260 ->Args({1024, 1024, 4, 1})
2261 ->Args({1024, 1024, 8, 1})
2262 ->Args({640, 2048, 1, 1})
2263 ->Args({640, 2048, 4, 1})
2264 ->Args({640, 2048, 8, 1})
2265 ->Args({2048, 2048, 1, 1})
2266 ->Args({2048, 2048, 8, 1});
2267
2268 #endif // DOTPROD_BENCHMARKS
2269