/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

#include <algorithm>
#include <cmath>
#include <cstdarg>
#include <cstdint>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/experimental/ruy/ruy.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

namespace tflite {

namespace {

using cpu_backend_gemm::Gemm;
using cpu_backend_gemm::GemmParams;
using cpu_backend_gemm::MatrixParams;
using cpu_backend_gemm::QuantizationFlavor;

template <typename Scalar>
std::string ToString(const std::vector<Scalar>& vector) {
  std::stringstream s;
  if (vector.empty()) {
    s << "{}";
  } else {
    s << "{ " << static_cast<double>(vector[0]);
    for (int i = 1; i < vector.size(); i++) {
      s << ", " << static_cast<double>(vector[i]);
    }
    s << "}";
  }
  return s.str();
}

template <typename Scalar>
void MakeDeterministicPseudoRandomVector(int size,
                                         std::vector<Scalar>* vector) {
  // Intentionally create a new local random_engine in each invocation,
  // so pseudorandom values don't depend on invocation order.
  // Otherwise, test results would be affected by e.g. test filtering.
  std::default_random_engine random_engine;
  (void)random_engine();
  // Do not use std::uniform*_distribution: the values that it
  // generates are implementation-defined.
  const double random_min = static_cast<double>(random_engine.min());
  const double random_max = static_cast<double>(random_engine.max());
  const double result_min =
      std::is_floating_point<Scalar>::value
          ? -1.0
          : std::max(-256., static_cast<double>(
                                std::numeric_limits<Scalar>::lowest()));
  const double result_max =
      std::is_floating_point<Scalar>::value
          ? 1.0
          : std::min(256.,
                     static_cast<double>(std::numeric_limits<Scalar>::max()));
  const double random_scale =
      (result_max - result_min) / (random_max - random_min);

  vector->resize(size);
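  // The loop below affinely maps the engine's output range
  // [random_min, random_max] onto [result_min, result_max]; e.g. for
  // std::int8_t the target range works out to [-128, 127].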
  for (int i = 0; i < size; i++) {
    double val = result_min + random_scale * (random_engine() - random_min);
    val = std::max(val,
                   static_cast<double>(std::numeric_limits<Scalar>::lowest()));
    val =
        std::min(val, static_cast<double>(std::numeric_limits<Scalar>::max()));
    (*vector)[i] = static_cast<Scalar>(val);
  }
}

template <typename Scalar>
void MakeVectorFilledWithConsecutiveInts(int size,
                                         std::vector<Scalar>* vector) {
  vector->resize(size);
  EXPECT_LE(size, std::numeric_limits<Scalar>::max());
  for (int i = 0; i < size; i++) {
    (*vector)[i] = static_cast<Scalar>(i + 1);
  }
}

template <typename Scalar>
Scalar Median(const std::vector<Scalar>& vector) {
  EXPECT_GT(vector.size(), 0);
  std::vector<Scalar> vector_copy = vector;
  std::sort(std::begin(vector_copy), std::end(vector_copy));
  return vector_copy[vector_copy.size() / 2];
}

template <typename Scalar>
double MedianAbs(const std::vector<Scalar>& vector) {
  EXPECT_GT(vector.size(), 0);
  std::vector<double> vector_abs;
  vector_abs.resize(vector.size());
  for (int i = 0; i < vector.size(); i++) {
    vector_abs[i] = std::abs(static_cast<double>(vector[i]));
  }
  std::sort(std::begin(vector_abs), std::end(vector_abs));
  return vector_abs[vector_abs.size() / 2];
}

template <typename Scalar>
void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
           std::vector<Scalar>* dst) {
  dst->resize(src.size());
  for (int i = 0; i < src.size(); i++) {
    (*dst)[i] = std::max(std::min(src[i], clamp_max), clamp_min);
  }
}

template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
void Clamp(const GemmParams<AccumScalar, DstScalar, quantization_flavor>& src,
           DstScalar clamp_min, DstScalar clamp_max,
           GemmParams<AccumScalar, DstScalar, quantization_flavor>* dst) {
  *dst = src;
  dst->clamp_min = clamp_min;
  dst->clamp_max = clamp_max;
}

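// Summary statistics of the elementwise difference between an actual and an
// expected result vector. scale_factor records max |expected| so that the
// relative tolerances in CheckErrorStats can be turned into absolute bounds.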
struct ErrorStats {
  int size;
  double scale_factor;
  double max_abs_diff;
  double mean_abs_diff;
  double abs_mean_diff;
};

template <typename Scalar>
void ComputeErrorStats(const std::vector<Scalar>& actual,
                       const std::vector<Scalar>& expected,
                       ErrorStats* error_stats) {
  double max_abs_diff = 0;
  double sum_abs_diff = 0;
  double sum_diff = 0;
  double max_abs_expected = 0;
  EXPECT_EQ(actual.size(), expected.size());
  for (int i = 0; i < actual.size(); i++) {
    double actual_val = static_cast<double>(actual[i]);
    double expected_val = static_cast<double>(expected[i]);
    double diff = actual_val - expected_val;
    max_abs_expected = std::max(max_abs_expected, std::abs(expected_val));
    sum_diff += diff;
    sum_abs_diff += std::abs(diff);
    max_abs_diff = std::max(max_abs_diff, std::abs(diff));
  }
  error_stats->scale_factor = max_abs_expected;
  error_stats->max_abs_diff = max_abs_diff;
  error_stats->mean_abs_diff = sum_abs_diff / actual.size();
  error_stats->abs_mean_diff = std::abs(sum_diff / actual.size());
  error_stats->size = actual.size();
}

template <typename AccumScalar, typename DstScalar>
bool CheckErrorStats(const ErrorStats& error_stats, int accumulation_depth) {
  double tolerated_relative_max_abs_diff = 0;
  double tolerated_relative_mean_abs_diff = 0;
  double tolerated_relative_abs_mean_diff = 0;

  double inverse_size = 1. / error_stats.size;

  if (std::is_floating_point<AccumScalar>::value) {
    // Somewhat naive requirement: the worst case would be epsilon-sized
    // errors all adding up in the same direction, on values of the same
    // magnitude.
    tolerated_relative_max_abs_diff =
        accumulation_depth * std::numeric_limits<DstScalar>::epsilon();
    // A naive interpretation of the Central Limit Theorem is the rationale
    // for the sqrt here. We haven't worked out the correct scale factor,
    // or how applicable that theorem is here (the random variables being
    // added might not be mutually independent).
    tolerated_relative_mean_abs_diff =
        std::sqrt(static_cast<double>(accumulation_depth)) *
        std::numeric_limits<DstScalar>::epsilon();
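    // For example, with float accumulation at depth 1024 (epsilon is about
    // 1.19e-7), this tolerates a relative max_abs_diff of about 1.2e-4 and
    // a relative mean_abs_diff of about 3.8e-6.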
    // Unbiasedness requirement: we require the bias, abs_mean_diff, to be
    // much smaller than the mean_abs_diff, except when there are very few
    // values.
    tolerated_relative_abs_mean_diff =
        tolerated_relative_mean_abs_diff * std::sqrt(inverse_size);
  } else {
    // In quantized arithmetic, tolerate minor rounding differences, resulting
    // in off-by-one errors (tolerated_relative_max_abs_diff = 1), as long
    // as they are rare (tolerated_relative_mean_abs_diff) and unbiased
    // (tolerated_relative_abs_mean_diff).
    tolerated_relative_max_abs_diff = 1;
    // Naively require mean_abs_diff and abs_mean_diff to converge to zero
    // as size gets large. We don't know how fast that convergence should
    // be: this is based on trial and error, striking a compromise between
    // something that works and code that is simple enough not to feel too
    // ad hoc. As in the float path above, abs_mean_diff is subject to a
    // stricter requirement because it is a bias.
    tolerated_relative_mean_abs_diff = std::sqrt(inverse_size) * 0.5;
    tolerated_relative_abs_mean_diff = inverse_size * 2.;
  }

  double tolerated_max_abs_diff =
      tolerated_relative_max_abs_diff * error_stats.scale_factor;
  double tolerated_mean_abs_diff =
      tolerated_relative_mean_abs_diff * error_stats.scale_factor;
  double tolerated_abs_mean_diff =
      tolerated_relative_abs_mean_diff * error_stats.scale_factor;

  EXPECT_LE(error_stats.max_abs_diff, tolerated_max_abs_diff);
  EXPECT_LE(error_stats.mean_abs_diff, tolerated_mean_abs_diff);
  EXPECT_LE(error_stats.abs_mean_diff, tolerated_abs_mean_diff);

  return error_stats.max_abs_diff <= tolerated_max_abs_diff &&
         error_stats.mean_abs_diff <= tolerated_mean_abs_diff &&
         error_stats.abs_mean_diff <= tolerated_abs_mean_diff;
}

template <typename AccumScalar, typename DstScalar>
void CheckErrorForAccumulation(int accumulation_depth,
                               const std::vector<DstScalar>& actual,
                               const std::vector<DstScalar>& expected) {
  ErrorStats error_stats;
  ComputeErrorStats(actual, expected, &error_stats);
  bool success =
      CheckErrorStats<AccumScalar, DstScalar>(error_stats, accumulation_depth);
  EXPECT_TRUE(success) << "Actual vector\n"
                       << ToString(actual) << "\ndiffers from expected vector\n"
                       << ToString(expected) << "\n";
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void PerformGemmThenCompareResultsThenAgainWithClamping(
    const MatrixParams<LhsScalar>& lhs_params,
    const std::vector<LhsScalar>& lhs_data,
    const MatrixParams<RhsScalar>& rhs_params,
    const std::vector<RhsScalar>& rhs_data,
    const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
    const std::vector<DstScalar>& expected,
    CpuBackendContext* cpu_backend_context) {
  const int accumulation_depth = lhs_params.cols;
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected);
  DstScalar expected_median = Median(expected);
  std::vector<DstScalar> expected_with_clamp;
  GemmParams<AccumScalar, DstScalar, quantization_flavor> params_with_clamp;
  DstScalar clamp_min, clamp_max;

  clamp_min = std::numeric_limits<DstScalar>::lowest();
  clamp_max = expected_median;
  Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
  Clamp(params, clamp_min, clamp_max, &params_with_clamp);
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_with_clamp, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected_with_clamp);

  clamp_min = expected_median;
  clamp_max = std::numeric_limits<DstScalar>::max();
  Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
  Clamp(params, clamp_min, clamp_max, &params_with_clamp);
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_with_clamp, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected_with_clamp);
}

// When generating testcases for a quantized GEMM, it's not trivial to
// pick multiplier exponents: too low a value will result in too many zeros,
// too high a value will result in too many large clamped values; in both
// cases testing coverage is harmed. To ensure good testing coverage, we
// therefore have to find a multiplier exponent that's just right. It would
// be possible to do so by analyzing the random distribution of values in
// the result matrix; that however would require some mathematical work that
// we haven't done so far. Until that is done, the best we can do is to
// search for a good exponent value by trial and error. This is expensive,
// as each try requires computing a whole GEMM, and is thus probably a major
// contribution to the overall latency of this test. To partially mitigate
// that, we use a bisection to reduce the required number of tries.
//
// This function is recursive. The bisect_min and bisect_max arguments
// are the current bisection bounds. It performs a Gemm with the midpoint,
// named bisect_mid, as the multiplier exponent. Based on whether the values
// in the resulting matrix are rather too low or too large in absolute
// value, it then recurses into the corresponding half of the bisection range.
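// For a 32-bit accumulator the initial range used below is [-32, 0], so the
// bisection needs only about log2(33), i.e. roughly 6, GEMM evaluations.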
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
int BisectReasonableMultiplierExponent(
    int bisect_min, int bisect_max, const MatrixParams<LhsScalar>& lhs_params,
    const std::vector<LhsScalar>& lhs_data,
    const MatrixParams<RhsScalar>& rhs_params,
    const std::vector<RhsScalar>& rhs_data,
    const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
    const GemmParams<AccumScalar, DstScalar>& params,
    CpuBackendContext* cpu_backend_context) {
  if (bisect_min == bisect_max) {
    return bisect_min;
  }
  // Compute the midpoint as the floor of the average of bisect_min and
  // bisect_max. As C++ integer division rounds towards zero and our values
  // may be of either sign, this is not trivial to implement using only
  // integer arithmetic.
  int bisect_mid =
      static_cast<int>(std::floor(0.5 * (bisect_min + bisect_max)));
  GemmParams<AccumScalar, DstScalar> params_copy(params);
  params_copy.multiplier_exponent = bisect_mid;
  double clamp_abs = std::max(std::abs(static_cast<double>(params.clamp_min)),
                              std::abs(static_cast<double>(params.clamp_max)));
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_copy, cpu_backend_context);
  double median_abs = MedianAbs(*dst_data);
  if (median_abs < 0.25 * clamp_abs) {
    return BisectReasonableMultiplierExponent(
        bisect_mid + 1, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, dst_data, params_copy, cpu_backend_context);
  } else {
    return BisectReasonableMultiplierExponent(
        bisect_min, bisect_mid, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, dst_data, params_copy, cpu_backend_context);
  }
}

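// Computes the same GEMM through ruy's slow reference path
// (ruy::Path::kReference). The random tests use this as their ground truth.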
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void ReferenceGemm(
    const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
    const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
    const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
    CpuBackendContext* context) {
  ruy::Matrix<LhsScalar> ruy_lhs;
  ruy::Matrix<RhsScalar> ruy_rhs;
  ruy::Matrix<DstScalar> ruy_dst;
  cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs);
  cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs);
  cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &ruy_dst);

  ruy::BasicSpec<AccumScalar, DstScalar> ruy_spec;
  cpu_backend_gemm::detail::MakeRuySpec(params, &ruy_spec);

  ruy::Mul<ruy::Path::kReference>(ruy_lhs, ruy_rhs, ruy_spec,
                                  context->ruy_context(), &ruy_dst);
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void TestSomeGemm(int rows, int depth, int cols,
                  const std::vector<DstScalar>& golden) {
  CpuBackendContext cpu_backend_context;
  std::default_random_engine random_engine;
  cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
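  // The SetMaxNumThreads call above exercises the multithreaded paths too:
  // it picks between 1 and 8 threads pseudorandomly.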

  const bool use_golden = !golden.empty();

  std::vector<LhsScalar> lhs_data;
  std::vector<RhsScalar> rhs_data;
  std::vector<AccumScalar> bias_data;
  std::vector<DstScalar> dst_data;
  if (use_golden) {
    MakeVectorFilledWithConsecutiveInts(rows * depth, &lhs_data);
    MakeVectorFilledWithConsecutiveInts(depth * cols, &rhs_data);
    MakeVectorFilledWithConsecutiveInts(rows, &bias_data);
  } else {
    MakeDeterministicPseudoRandomVector(rows * depth, &lhs_data);
    MakeDeterministicPseudoRandomVector(depth * cols, &rhs_data);
    MakeDeterministicPseudoRandomVector(rows, &bias_data);
  }
  MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);

  MatrixParams<LhsScalar> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = rows;
  lhs_params.cols = depth;
  if (!std::is_floating_point<LhsScalar>::value) {
    lhs_params.zero_point = 1;
    if (!use_golden) {
      lhs_params.zero_point += random_engine() % 8;
    }
  }

  MatrixParams<RhsScalar> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = depth;
  rhs_params.cols = cols;
  if (!std::is_floating_point<RhsScalar>::value) {
    rhs_params.zero_point = 1;
    if (!use_golden) {
      rhs_params.zero_point += random_engine() % 8;
    }
  }

  MatrixParams<DstScalar> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = rows;
  dst_params.cols = cols;
  if (!std::is_floating_point<DstScalar>::value) {
    dst_params.zero_point = 1;
    if (!use_golden) {
      dst_params.zero_point += random_engine() % 8;
    }
  }

  GemmParams<AccumScalar, DstScalar> params;
  if (use_golden || (random_engine() % 2)) {
    // cpu_backend_gemm supports bias=null only in the float path. Test that
    // in 50% of float testcases.
    params.bias = bias_data.data();
  }
  static constexpr std::int32_t kMultiplierFixedpointMin = 1234567890;
  static constexpr std::int32_t kMultiplierFixedpointMax = 1987654321;
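  // Interpreted as Q0.31 fixed-point values (31 fractional bits), these
  // bounds correspond to real multipliers of roughly 0.57 and 0.93.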
  if (!std::is_floating_point<AccumScalar>::value) {
    // Some large int32 value. Not being a multiple of a large power of two
    // helps test rounding behavior.
    params.multiplier_fixedpoint = kMultiplierFixedpointMin;
    // Now find a suitable value for multiplier_exponent.
    // It needs to be low enough for a substantial amount of dst values
    // to avoid getting clamped.
    int bisect_min = -8 * static_cast<int>(sizeof(AccumScalar));
    // We don't increase test coverage by using positive multiplier
    // exponents, and using very large positive ones may at the moment
    // result in overflow in some paths.
    // TODO(benoitjacob): fix that.
    int bisect_max = 0;
    params.multiplier_exponent = BisectReasonableMultiplierExponent(
        bisect_min, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, &dst_data, params, &cpu_backend_context);
  }

  std::vector<DstScalar> expected;
  if (use_golden) {
    EXPECT_EQ(golden.size(), dst_data.size());
    expected = golden;
  } else {
    expected.resize(dst_data.size());
    ReferenceGemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(),
                  dst_params, expected.data(), params, &cpu_backend_context);
  }

  PerformGemmThenCompareResultsThenAgainWithClamping(
      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data, params,
      expected, &cpu_backend_context);

  if (!use_golden && !std::is_floating_point<AccumScalar>::value) {
    // Try with per-channel quantized multipliers.
    std::vector<AccumScalar> multiplier_fixedpoint_perchannel(rows);
    std::vector<int> multiplier_exponent_perchannel(rows);
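    // Give each output row its own (fixedpoint, exponent) pair: fixedpoint
    // values drawn from [kMultiplierFixedpointMin, kMultiplierFixedpointMax],
    // exponents within +/-2 of the globally bisected exponent.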
    for (int i = 0; i < rows; i++) {
      multiplier_fixedpoint_perchannel[i] =
          kMultiplierFixedpointMin +
          (random_engine() %
           (kMultiplierFixedpointMax + 1 - kMultiplierFixedpointMin));
      const int exponent_min = params.multiplier_exponent - 2;
      const int exponent_max = params.multiplier_exponent + 2;
      multiplier_exponent_perchannel[i] =
          exponent_min + (random_engine() % (exponent_max + 1 - exponent_min));
    }
    static constexpr QuantizationFlavor perchannel_flavor =
        std::is_floating_point<AccumScalar>::value
            ? QuantizationFlavor::kFloatingPoint
            : QuantizationFlavor::kIntegerWithPerRowMultiplier;
    GemmParams<AccumScalar, DstScalar, perchannel_flavor> params_perchannel;
    params_perchannel.bias = params.bias;
    params_perchannel.clamp_min = params.clamp_min;
    params_perchannel.clamp_max = params.clamp_max;
    params_perchannel.multiplier_fixedpoint_perchannel =
        multiplier_fixedpoint_perchannel.data();
    params_perchannel.multiplier_exponent_perchannel =
        multiplier_exponent_perchannel.data();
    ReferenceGemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(),
                  dst_params, expected.data(), params_perchannel,
                  &cpu_backend_context);
    PerformGemmThenCompareResultsThenAgainWithClamping(
        lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data,
        params_perchannel, expected, &cpu_backend_context);
  }
}

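// The golden vectors below can be checked by hand. In the Float case,
// lhs is the row-major 2x3 matrix {{1, 2, 3}, {4, 5, 6}}, rhs is the
// column-major 3x4 matrix with columns {1, 2, 3} ... {10, 11, 12}, bias is
// {1, 2}, and the golden value is the column-major dst = lhs * rhs + bias.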
TEST(CpuBackendGemmSimpleTestAgainstGolden, Float) {
  TestSomeGemm<float, float, float, float>(2, 3, 4,
                                           {15, 34, 33, 79, 51, 124, 69, 169});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Uint8) {
  TestSomeGemm<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>(
      5, 2, 3, {2, 4, 6, 7, 9, 3, 10, 16, 22, 29, 4, 15, 26, 37, 48});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8) {
  TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int8_t>(
      2, 6, 3, {13, 32, 31, 81, 50, 127});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8Int16) {
  TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int16_t>(
      3, 5, 4, {19, 48, 77, 48, 149, 250, 76, 249, 422, 105, 350, 595});
}

template <typename tLhsScalar, typename tRhsScalar, typename tAccumScalar,
          typename tDstScalar>
struct TypesTuple {
  using LhsScalar = tLhsScalar;
  using RhsScalar = tRhsScalar;
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;
};

template <typename TypesTupleType>
void TestRandomGemms(const std::vector<std::tuple<int, int, int>>& shapes) {
  using LhsScalar = typename TypesTupleType::LhsScalar;
  using RhsScalar = typename TypesTupleType::RhsScalar;
  using AccumScalar = typename TypesTupleType::AccumScalar;
  using DstScalar = typename TypesTupleType::DstScalar;
  for (const auto& shape : shapes) {
    int rows = std::get<0>(shape);
    int depth = std::get<1>(shape);
    int cols = std::get<2>(shape);
    TestSomeGemm<LhsScalar, RhsScalar, AccumScalar, DstScalar>(rows, depth,
                                                               cols, {});
  }
}

template <typename TypesTupleType>
class CpuBackendGemmTest : public testing::Test {};

TYPED_TEST_SUITE_P(CpuBackendGemmTest);

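// Each TypesTuple instantiation below is (LhsScalar, RhsScalar, AccumScalar,
// DstScalar); the last one exercises the mixed case of uint8 inputs with an
// int8 destination.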
typedef ::testing::Types<
    TypesTuple<float, float, float, float>,
    TypesTuple<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>,
    TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int8_t>,
    TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int16_t>,
    TypesTuple<std::uint8_t, std::uint8_t, std::int32_t, std::int8_t>>
    CpuBackendGemmTestInstantiations;

TYPED_TEST_SUITE(CpuBackendGemmTest, CpuBackendGemmTestInstantiations);

TYPED_TEST(CpuBackendGemmTest, Square) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, SquarePowerOfTwo) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 64; size <= 128; size *= 2) {
    shapes.push_back(std::make_tuple(size, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, MatrixTimesVector) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(size, size, 1));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, VectorTimesMatrix) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(1, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, MatrixTimesNarrow) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size, 2));
    shapes.push_back(std::make_tuple(size, size, 3));
    shapes.push_back(std::make_tuple(size, size, 4));
    shapes.push_back(std::make_tuple(size, size, 8));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, Rectangular) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size + 5, size + 1));
    shapes.push_back(std::make_tuple(size + 10, size + 2, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, HighlyRectangular) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size <= 1000; size *= 10) {
    shapes.push_back(std::make_tuple(size, 10, 10));
    shapes.push_back(std::make_tuple(10, size, 10));
    shapes.push_back(std::make_tuple(10, 10, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, InnerProduct) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(1, size, 1));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, OuterProduct) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 100; size++) {
    shapes.push_back(std::make_tuple(size, 1, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

}  // namespace

}  // namespace tflite

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}