• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite FULLY_CONNECTED op.
16 
17 #include "tensorflow/lite/kernels/fully_connected.h"
18 
19 #include <stddef.h>
20 #include <stdint.h>
21 
22 #include <algorithm>
23 #include <initializer_list>
24 #include <limits>
25 #include <map>
26 #include <memory>
27 #include <random>
28 #include <string>
29 #include <vector>
30 
31 #include <gmock/gmock.h>
32 #include <gtest/gtest.h>
33 #include "absl/memory/memory.h"
34 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
35 #include "tensorflow/lite/core/api/op_resolver.h"
36 #include "tensorflow/lite/interpreter.h"
37 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
38 #include "tensorflow/lite/kernels/test_util.h"
39 #include "tensorflow/lite/schema/schema_generated.h"
40 #include "tensorflow/lite/string_type.h"
41 
42 namespace tflite {
43 namespace {
44 
45 using ::testing::ElementsAre;
46 using ::testing::ElementsAreArray;
47 
// Input activations for the large randomized float test below:
// 16 batches x 8 features, values in [0, 1).
static float fully_connected_input[] = {
    0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653,
    0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390,
    0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314,
    0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550,
    0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112,
    0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999,
    0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142,
    0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494,
    0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081,
    0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552,
    0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320,
    0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458,
    0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115,
    0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771,
    0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582,
    0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962,
    0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202,
    0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691,
    0.921330, 0.139902};
68 
// Expected (post-ReLU) outputs for the input above: 16 batches x 16 units,
// one blank-line-separated group per batch.
static float fully_connected_golden_output[] = {
    0,        0.0732134,   0,        0,          0,         0.280859,
    0,        0.128927,    0,        0.0777251,  0,         0.270268,
    0.271435, 0.0173503,   0.335465, 0.235562,

    0,        0.0745866,   0,        0.051611,   0,         0.253876,
    0,        0.0814873,   0,        0.104104,   0,         0.248529,
    0.264194, 0,           0.302973, 0.166252,

    0,        0.0170409,   0,        0.0509851,  0,         0.212834,
    0,        0.0208326,   0,        0.129932,   0.203978,  0.103428,
    0.298051, 0,           0.332233, 0.00445903,

    0,        0.125246,    0,        0.0735336,  0,         0.0910256,
    0,        0,           0,        0.18933,    0.378111,  0.0712443,
    0.277298, 0.0123414,   0.267454, 0,

    0,        0.14687,     0,        0.155495,   0.0300215, 0.147256,
    0,        0,           0,        0.156412,   0.434914,  0.0461529,
    0.246508, 0,           0.363138, 0,

    0,        0,           0,        0.0212949,  0,         0.301708,
    0,        0.35497,     0,        0.406223,   0.0260211, 0.049195,
    0.197161, 0,           0.37316,  0,

    0,        0.221783,    0,        0,          0.0116515, 0.281945,
    0,        0,           0,        0,          0.285626,  0.181773,
    0.296401, 0.170452,    0.367135, 0.142597,

    0,        0,           0,        0,          0,         0.418886,
    0,        0.291063,    0,        0.227541,   0.0424759, 0.27589,
    0.398286, 0.177146,    0.40359,  0.121452,

    0,        0.0834884,   0,        0,          0,         0.287441,
    0,        0.0046838,   0,        0.0122087,  0,         0.217376,
    0.140183, 0.0948412,   0.436677, 0.0589876,

    0,        0.0289969,   0,        0.0921397,  0,         0.396802,
    0,        0.0126157,   0,        0.0968433,  0,         0.172271,
    0.173295, 0.0664741,   0.53645,  0.00915603,

    0,        0,           0,        0,          0,         0.147942,
    0,        0.263795,    0,        0.39782,    0,         0.382435,
    0.561072, 0.0579847,   0.145712, 0.13508,

    0,        0,           0,        0.16382,    0,         0.322294,
    0,        0.163798,    0,        0.405211,   0.367953,  0.076852,
    0.342473, 0.0834118,   0.377537, 0,

    0,        0.206,       0,        0,          0,         0.375769,
    0,        0,           0,        0,          0,         0.125165,
    0,        0.105591,    0.52055,  0.0536445,

    0,        0.259261,    0,        0,          0,         0.247707,
    0,        0,           0,        0,          0,         0.215862,
    0.149153, 0.224678,    0.359519, 0.129419,

    0,        0.17611,     0,        0.280895,   0,         0.576484,
    0,        0.000418848, 0,        0,          0,         0.151112,
    0.211902, 0,           0.566341, 0.106305,

    0,        0.0246284,   0,        0,          0,         0.196267,
    0,        0.0248624,   0,        0.265635,   0,         0.436199,
    0.408079, 0.134514,    0.328489, 0.411368};
133 
134 class BaseFullyConnectedOpModel : public SingleOpModel {
135  public:
136   // TODO(ahentz): test different activation types too.
BaseFullyConnectedOpModel(TfLiteRegistration * registration,int units,int batches,const TensorData & input,const TensorData & output={TensorType_FLOAT32},bool keep_num_dims=false,bool bias_tensor_optional=false,ActivationFunctionType activation_func=ActivationFunctionType_RELU,FullyConnectedOptionsWeightsFormat weights_format=FullyConnectedOptionsWeightsFormat_DEFAULT,bool add_bias_for_quantized=true)137   BaseFullyConnectedOpModel(
138       TfLiteRegistration* registration, int units, int batches,
139       const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
140       bool keep_num_dims = false, bool bias_tensor_optional = false,
141       ActivationFunctionType activation_func = ActivationFunctionType_RELU,
142       FullyConnectedOptionsWeightsFormat weights_format =
143           FullyConnectedOptionsWeightsFormat_DEFAULT,
144       bool add_bias_for_quantized = true)
145       : batches_(batches), units_(units) {
146     int total_input_size = 1;
147     for (size_t i = 0; i < input.shape.size(); ++i) {
148       total_input_size *= input.shape[i];
149     }
150     input_size_ = total_input_size / batches_;
151 
152     input_ = AddInput(input);
153     if (input.type == TensorType_INT16) {
154       weights_ = AddInput({TensorType_INT8, {units_, input_size_}, -63.5, 64});
155     } else {
156       weights_ =
157           AddInput({input.type, {units_, input_size_}, input.min, input.max});
158     }
159 
160     if (bias_tensor_optional) {
161       bias_ = AddNullInput();
162     } else if (input.type == TensorType_FLOAT32) {
163       bias_ = AddInput({TensorType_FLOAT32, {units_}});
164     } else if (add_bias_for_quantized) {
165       // This is a quantized version. The scale of 'bias' depends on the scales
166       // of input and filter. Supposedly this is correctly set during quantized
167       // training.
168       auto bias_scale = GetScale(input_) * GetScale(weights_);
169       if (input.type == TensorType_INT16) {
170         TensorData bias{TensorType_INT64, {units_}, 0, 0, bias_scale};
171         bias_ = AddInput(bias);
172       } else {
173         TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
174         bias_ = AddInput(bias);
175       }
176     }
177 
178     output_ = AddOutput(output);
179     if (weights_format != FullyConnectedOptionsWeightsFormat_DEFAULT) {
180       AddOutput({TensorType_UINT8, input.shape});
181     }
182 
183     SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
184                  BuiltinOptions_FullyConnectedOptions,
185                  CreateFullyConnectedOptions(builder_, activation_func,
186                                              weights_format, keep_num_dims)
187                      .Union());
188     resolver_ = absl::make_unique<SingleOpResolver>(
189         BuiltinOperator_FULLY_CONNECTED, registration);
190     std::vector<std::vector<int>> inputs = {GetShape(input_),
191                                             GetShape(weights_)};
192     if (add_bias_for_quantized) {
193       inputs.push_back((bias_ == kTfLiteOptionalTensor) ? std::vector<int>()
194                                                         : GetShape(bias_));
195     }
196     BuildInterpreter(inputs);
197   }
198 
input_size()199   int input_size() { return input_size_; }
num_units()200   int num_units() { return units_; }
num_batches()201   int num_batches() { return batches_; }
202 
203  protected:
204   int input_;
205   int weights_;
206   int bias_;
207   int output_;
208 
209   int batches_;
210   int units_;
211   int input_size_;
212 };
213 
214 class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
215  public:
216   using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
217 
SetBias(const std::vector<float> & f)218   void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
219 
SetWeights(const std::vector<float> & f)220   void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
221 
SetInput(const std::vector<float> & data)222   void SetInput(const std::vector<float>& data) {
223     PopulateTensor(input_, data);
224   }
SetInput(int offset,float * begin,float * end)225   void SetInput(int offset, float* begin, float* end) {
226     PopulateTensor(input_, offset, begin, end);
227   }
228 
GetOutput()229   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()230   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
231 };
232 
233 class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
234  public:
235   using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
236 
SetBias(const std::vector<float> & data)237   void SetBias(const std::vector<float>& data) {
238     QuantizeAndPopulate<int32_t>(bias_, data);
239   }
SetBias64(const std::vector<float> & data)240   void SetBias64(const std::vector<float>& data) {
241     QuantizeAndPopulate<int64_t>(bias_, data);
242   }
243   template <typename T>
SetWeights(const std::vector<float> & data)244   void SetWeights(const std::vector<float>& data) {
245     QuantizeAndPopulate<T>(weights_, data);
246   }
247 
248   template <typename T>
ShuffleAndSetWeights(const std::vector<float> & data,int input_depth,int output_depth)249   void ShuffleAndSetWeights(const std::vector<float>& data, int input_depth,
250                             int output_depth) {
251     std::vector<float> shuffled_data(data.size());
252     CHECK_EQ(input_depth % 16, 0);
253     CHECK_EQ(output_depth % 4, 0);
254     float* shuffled_data_ptr = shuffled_data.data();
255     for (int block_o = 0; block_o < output_depth; block_o += 4) {
256       for (int block_i = 0; block_i < input_depth; block_i += 16) {
257         for (int o = 0; o < 4; o++) {
258           for (int i = 0; i < 16; i++) {
259             *shuffled_data_ptr++ =
260                 data[(block_o + o) * input_depth + block_i + i];
261           }
262         }
263       }
264     }
265     TfLiteTensor* t = interpreter_->tensor(weights_);
266     auto quantized_data =
267         Quantize<T>(shuffled_data, t->params.scale, t->params.zero_point);
268     for (T& q : quantized_data) {
269       q ^= 0x80;
270     }
271     PopulateTensor(weights_, 0, quantized_data.data(),
272                    quantized_data.data() + quantized_data.size());
273   }
274 
275   template <typename T>
SetInput(const std::vector<float> & data)276   void SetInput(const std::vector<float>& data) {
277     QuantizeAndPopulate<T>(input_, data);
278   }
279 
280   template <typename T>
GetOutput()281   std::vector<T> GetOutput() {
282     return ExtractVector<T>(output_);
283   }
284 
285   template <typename T>
GetDequantizedOutput()286   std::vector<float> GetDequantizedOutput() {
287     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
288                          GetZeroPoint(output_));
289   }
290 };
291 
292 // In the hybrid model the weights are quantized (to uint8). But the bias,
293 // input (and output) are expected to be in float precision.
294 class HybridFullyConnectedOpModel : public SingleOpModel {
295  public:
HybridFullyConnectedOpModel(int units,int batches,const TensorData & input,const TensorData & weights,const TensorData & output={TensorType_FLOAT32},bool asymmetric_inputs=false,int num_threads=1)296   HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
297                               const TensorData& weights,
298                               const TensorData& output = {TensorType_FLOAT32},
299                               bool asymmetric_inputs = false,
300                               int num_threads = 1)
301       : batches_(batches), units_(units) {
302     int total_input_size = 1;
303     for (size_t i = 0; i < input.shape.size(); ++i) {
304       total_input_size *= input.shape[i];
305     }
306     input_size_ = total_input_size / batches_;
307 
308     input_ = AddInput(input);
309     weights_ = AddInput(weights);
310 
311     TensorData bias{TensorType_FLOAT32, {units_}};
312     bias_ = AddInput(bias);
313 
314     output_ = AddOutput(output);
315 
316     auto options = CreateFullyConnectedOptions(
317                        builder_, ActivationFunctionType_RELU,
318                        tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
319                        false, asymmetric_inputs)
320                        .Union();
321     SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
322                  BuiltinOptions_FullyConnectedOptions, options);
323     resolver_ = absl::make_unique<SingleOpResolver>(
324         BuiltinOperator_FULLY_CONNECTED,
325         ops::builtin::Register_FULLY_CONNECTED_PIE());
326     BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)},
327                      num_threads, /*allow_fp32_relax_to_fp16=*/false,
328                      /*apply_delegate=*/false);
329   }
SetBias(const std::vector<float> & f)330   void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
SetWeights(const std::vector<float> & data)331   void SetWeights(const std::vector<float>& data) {
332     SymmetricQuantizeAndPopulate(weights_, data);
333   }
334 
SetSignedWeights(std::initializer_list<float> f)335   void SetSignedWeights(std::initializer_list<float> f) {
336     SignedSymmetricQuantizeAndPopulate(weights_, f);
337   }
338 
SetInput(const std::vector<float> & f)339   void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
GetOutput()340   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()341   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
342 
input_size()343   int input_size() { return input_size_; }
num_units()344   int num_units() { return units_; }
num_batches()345   int num_batches() { return batches_; }
346 
347  protected:
348   int input_;
349   int weights_;
350   int bias_;
351   int output_;
352 
353   int batches_;
354   int units_;
355   int input_size_;
356 };
357 
358 const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
359     {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
360     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
361     {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
362 });
363 
364 class FloatFullyConnectedOpTest : public SingleOpTest {
365  protected:
GetKernelMap()366   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
367     return *kKernelMap;
368   }
369 };
370 
371 const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
372     {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
373     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
374 });
375 
376 class QuantizedFullyConnectedOpTest : public SingleOpTest {
377  protected:
GetKernelMap()378   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
379     return *kKernelMapNoPie;
380   }
381 };
382 
383 const auto kKernelMapHybrid = new std::map<string, TfLiteRegistration*>({
384     {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
385     // Only Pie supports the hybrid path, so the optimized kernel should fall
386     // back to the Pie path in such cases.
387     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
388 });
389 
390 // Hybrid mode is used by the Pie quantized kernel.
391 class HybridFullyConnectedOpTest : public SingleOpTest {
392  protected:
GetKernelMap()393   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
394     return *kKernelMapHybrid;
395   }
396 };
397 
398 // TODO(ahentz): add more small tests like this one, focused on making sure the
399 // calculations are correct.
TEST_P(FloatFullyConnectedOpTest,SimpleTest)400 TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
401   FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/3, /*batches=*/2,
402                                /*input=*/{TensorType_FLOAT32, {2, 10}});
403   m.SetWeights({
404       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
405       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
406       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
407   });
408   m.SetBias({1, 2, 3});
409 
410   m.SetInput({
411       1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
412       1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
413   });
414 
415   m.Invoke();
416 
417   EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
418   EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
419 }
420 
TEST_P(FloatFullyConnectedOpTest, SimpleTest2) {
  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/1, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 2}});
  m.SetBias({1});
  m.SetWeights({
      2, 4,  // u = 0
  });
  m.SetInput({
      1, 2,  // b = 0
      2, 1,  // b = 1
  });

  m.Invoke();

  // out[b] = dot(w, x[b]) + bias: 1*2 + 2*4 + 1 = 11 and 2*2 + 1*4 + 1 = 9.
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
}
439 
TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) {
  // The optimized kernel assumes that the bias is specified, so this case is
  // pinned to the PIE kernel.
  FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
                               /*units=*/1, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 2}},
                               /*output=*/{TensorType_FLOAT32},
                               /*keep_num_dims=*/false,
                               /*bias_tensor_optional=*/true);
  m.SetWeights({
      2, 4,  // u = 0
  });
  m.SetInput({
      1, 2,  // b = 0
      2, 1,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8));
}
462 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  m.SetBias({1, 2, 3});
  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  //
                  58, 59, 60,  //
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(151, 152, 153, 185, 186, 187));
}
492 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8NoBias) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128},
      /*keep_num_dims=*/false, /*bias_tensor_optional=*/false,
      /*activation_func=*/ActivationFunctionType_RELU,
      /*weights_format=*/FullyConnectedOptionsWeightsFormat_DEFAULT,
      /*add_bias_for_quantized=*/false);

  // input_product_scale < output_scale was not true.
  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Without the {1, 2, 3} bias every unit is one lower than in the
  // with-bias variant of this test (bias values 1..3 quantize to ~1 each).
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  23, 23, 23,  //
                  57, 57, 57,  //
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(150, 150, 150, 184, 184, 184));
}
526 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  m.SetBias({1, 2, 3});
  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(23, 24, 25, 57, 58, 59));
}
552 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16) {
  // Symmetric int16 quantization covering roughly [-64, 64).
  const float scale = 128.0 / 65536;
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
      /*output=*/{TensorType_INT16, {}, 0, 0, scale, 0});

  // input_product_scale < output_scale was not true.
  m.SetBias64({1, 2, 3});  // int16 activations pair with an int64 bias.
  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetInput<int16_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  EXPECT_THAT(m.GetOutput<int16_t>(),
              ElementsAre(12288, 12800, 13312, 29696, 30208, 30720));
}
580 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8NoBias) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128},
      /*keep_num_dims=*/false, /*bias_tensor_optional=*/false,
      /*activation_func=*/ActivationFunctionType_RELU,
      /*weights_format=*/FullyConnectedOptionsWeightsFormat_DEFAULT,
      /*add_bias_for_quantized=*/false);

  // input_product_scale < output_scale was not true.
  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(22, 22, 22, 56, 56, 56));
}
610 
611 // Test the GEMV path.
TEST_P(QuantizedFullyConnectedOpTest,SimpleTestSingleBatchQuantizedInt8)612 TEST_P(QuantizedFullyConnectedOpTest, SimpleTestSingleBatchQuantizedInt8) {
613   QuantizedFullyConnectedOpModel m(
614       GetRegistration(), /*units=*/4, /*batches*/ 1,
615       /*input=*/{TensorType_INT8, {1, 10}, -63.5, 64},
616       /*output=*/{TensorType_INT8, {}, -127, 128});
617 
618   // input_product_scale < output_scale was not true.
619   m.SetWeights<int8_t>({
620       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
621       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
622       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
623       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 3
624   });
625   m.SetBias({1, 2, 3, 4});
626 
627   m.SetInput<int8_t>({
628       1, 2, 3, 4, 5, 6, 7, -8, 9, -10  // b = 1
629   });
630 
631   m.Invoke();
632 
633   EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
634               ElementsAreArray(ArrayFloatNear({58, 59, 60, 61})));
635   EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(57, 58, 59, 60));
636 }
637 
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedOutputMultiplierGreaterThan1Uint8) {
  // Input/output ranges are chosen so that real_multiplier = 2.
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {2, 10}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  m.SetBias({1, 2, 3});
  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(175, 177, 179, 243, 245, 247));
}
668 
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedOutputMultiplierGreaterThan1Int8) {
  // Input/output ranges are chosen so that real_multiplier = 2.
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_INT8, {2, 10}, -127, 128},
      /*output=*/{TensorType_INT8, {}, -63.5, 64});

  m.SetBias({1, 2, 3});
  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(47, 49, 51, 115, 117, 119));
}
698 
// Randomized end-to-end check of the uint8-input / int16-output fully
// connected path. Generates pseudo-random weights, inputs and biases for the
// given shape, runs the op with the requested weights format, and compares
// the dequantized output against a float reference accumulation computed
// inline below.
void SimpleTestQuantizedInt16OutputCase(
    TfLiteRegistration* registration, int input_depth, int output_depth,
    int batches, FullyConnectedOptionsWeightsFormat weights_format) {
  // Zero point 128 with scale 1/128 centers the uint8 range on zero, giving
  // real values in roughly [-1, 1) for both weights and input.
  const uint8_t kWeightsZeroPoint = 128;
  const float kWeightsScale = 1.f / 128.f;
  const uint8_t kInputZeroPoint = 128;
  const float kInputScale = 1.f / 128.f;
  const float kInputMin = (0 - kInputZeroPoint) * kInputScale;
  const float kInputMax = (255 - kInputZeroPoint) * kInputScale;
  // Output ranges in [-8..8] encoded as int16
  const float kOutputScale = 8.f / 32768.f;
  const float kOutputMin = -32768 * kOutputScale;
  const float kOutputMax = 32767 * kOutputScale;

  QuantizedFullyConnectedOpModel m(
      registration, output_depth, batches,
      /*input=*/
      {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax},
      /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax},
      /*keep_num_dims=*/false,
      /*bias_tensor_optional=*/false,
      /*activation_func=*/ActivationFunctionType_NONE, weights_format);

  // Default-seeded engine: the generated data (and hence the test) is
  // deterministic across runs. NOTE: the three distributions below draw from
  // this single engine in order, so reordering the draws changes the data.
  std::mt19937 random_engine;
  // Some compilers don't support uint8_t for uniform_distribution.
  std::uniform_int_distribution<uint32_t> weights_dist(
      0, std::numeric_limits<uint8_t>::max());

  std::vector<float> weights_data(input_depth * output_depth);
  for (auto& w : weights_data) {
    uint8_t q = static_cast<uint8_t>(weights_dist(random_engine));
    // Store the dequantized value so the float reference below can use it
    // directly; the model re-quantizes on SetWeights.
    w = (q - kWeightsZeroPoint) * kWeightsScale;
  }

  // Based on weights_format, enforce any shape requirement for that format/path
  // and set the (possibly shuffled) weights.
  switch (weights_format) {
    case FullyConnectedOptionsWeightsFormat_DEFAULT:
      m.SetWeights<uint8_t>(weights_data);
      break;
    case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
      // The shuffled path currently supports only a restrictive subset of
      // shapes, described by the following assertions:
      CHECK_EQ(input_depth % 16, 0);
      CHECK_EQ(output_depth % 4, 0);
      CHECK(batches == 1 || batches == 4);
      m.ShuffleAndSetWeights<uint8_t>(weights_data, input_depth, output_depth);
      break;
    default:
      LOG(FATAL) << "Unhandled weights format";
  }

  // Some compilers don't support uint8_t for uniform_distribution.
  std::uniform_int_distribution<uint32_t> input_dist(
      0, std::numeric_limits<uint8_t>::max());
  std::vector<float> input_data(input_depth * batches);
  for (auto& i : input_data) {
    uint8_t q = static_cast<uint8_t>(input_dist(random_engine));
    i = (q - kInputZeroPoint) * kInputScale;
  }

  std::vector<float> bias_data(output_depth);
  // As the output ranges in [-8, 8], it's reasonable to have bias values
  // in [-1, 1], this won't result in too much saturation.
  std::uniform_real_distribution<float> bias_dist(-1.f, 1.f);
  for (auto& b : bias_data) {
    b = bias_dist(random_engine);
  }

  m.SetBias(bias_data);
  m.SetInput<uint8_t>(input_data);

  m.Invoke();

  // Float reference: accum = bias[o] + dot(input row, weight row), clamped to
  // the representable output range to mirror int16 saturation.
  std::vector<float> expected_output_data(output_depth * batches);
  for (int b = 0; b < batches; b++) {
    for (int o = 0; o < output_depth; o++) {
      float accum = bias_data[o];
      for (int i = 0; i < input_depth; i++) {
        accum +=
            input_data[b * input_depth + i] * weights_data[o * input_depth + i];
      }
      accum = std::min(accum, kOutputMax);
      accum = std::max(accum, kOutputMin);
      expected_output_data[b * output_depth + o] = accum;
    }
  }

  // Tolerance slightly above one output quantization step (8/32768 ~= 2.4e-4).
  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_output_data, 3e-4f)));
}
790 
// Sweeps a grid of shapes through the uint8->int16 fully connected path using
// the default (unshuffled) weights layout.
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedInt16OutputDefaultWeights) {
  constexpr int kSizes[] = {1, 3, 10, 100};
  for (int in_depth : kSizes) {
    for (int out_depth : kSizes) {
      for (int num_batches : kSizes) {
        SimpleTestQuantizedInt16OutputCase(
            GetRegistration(), in_depth, out_depth, num_batches,
            FullyConnectedOptionsWeightsFormat_DEFAULT);
      }
    }
  }
}
803 
// Exercises the SHUFFLED4x16INT8 weights layout. The shuffled weights block
// shape is 4x16 and the weights matrix is (output_depth x input_depth), so
// output_depth must be a multiple of 4 and input_depth a multiple of 16.
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) {
  constexpr int kNumBlocks[] = {1, 3};
  // The fast shuffled path is fully specialized per batch size and currently
  // supports only batches of 1 and 4; the whole point of that path is speed
  // at small batch sizes, while larger batches are served well enough by the
  // generic gemmlowp-based implementation.
  constexpr int kBatchSizes[] = {1, 4};
  for (int in_blocks : kNumBlocks) {
    for (int out_blocks : kNumBlocks) {
      const int input_depth = 16 * in_blocks;
      const int output_depth = 4 * out_blocks;
      for (int batch : kBatchSizes) {
        SimpleTestQuantizedInt16OutputCase(
            GetRegistration(), input_depth, output_depth, batch,
            FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
      }
    }
  }
}
827 
// Hybrid FC: float activations against uint8-quantized weights
// (weights tensor declared with scale 10/127). Because the weights are
// quantized, results are approximate, hence the loose 1.3 tolerance.
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  HybridFullyConnectedOpModel m(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/
      {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0});  // Hybrid

  // All three output units share the same weight row.
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Expected = dot(input, weights) + bias per unit (e.g. batch 0: 23 + bias).
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     24, 25, 26,  //
                                     58, 59, 60,  //
                                 },
                                 /*max_abs_error=*/1.3f)));
}
856 
// Same scenario as SimpleTestQuantizedUint8 above, but with int8-quantized
// weights (signed hybrid path).
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  HybridFullyConnectedOpModel m(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0});  // Hybrid

  m.SetSignedWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Quantized weights make the result approximate: allow 1.3 absolute error.
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     24, 25, 26,  //
                                     58, 59, 60,  //
                                 },
                                 /*max_abs_error=*/1.3f)));
}
884 
// Multithreaded variant of the int8 hybrid test: the same 4-batch problem is
// run with 1..4 threads and must produce the same (approximate) results.
// Batches 2/3 duplicate batches 0/1, so the expected output repeats.
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8MultiThreaded) {
  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
    HybridFullyConnectedOpModel m(
        /*units=*/3, /*batches=*/4,
        /*input=*/{TensorType_FLOAT32, {4, 10}},
        /*weights=*/
        {TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
        /*output=*/{TensorType_FLOAT32}, /*asymmetric_inputs=*/false,
        /*num_threads=*/num_threads);  // Hybrid

    m.SetSignedWeights({
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
    });
    m.SetBias({1, 2, 3});

    m.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 2
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 3
    });

    m.Invoke();

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 3));
    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                   {
                                       24, 25, 26,  //
                                       58, 59, 60,  //
                                       24, 25, 26,  //
                                       58, 59, 60,  //
                                   },
                                   /*max_abs_error=*/1.3f)));
  }
}
922 
// Hybrid FC with asymmetric input quantization enabled (uint8 weights).
// Note the tighter tolerance (0.64 vs 1.3 in the symmetric test) — the
// asymmetric path is expected to be more accurate here.
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  HybridFullyConnectedOpModel m(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/
      {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32},
      /*asymmetric_quantize_input*/ true);  // Hybrid asymmetric

  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     24, 25, 26,  //
                                     58, 59, 60,  //
                                 },
                                 /*max_abs_error=*/0.64f)));
}
952 
// Hybrid FC with asymmetric input quantization enabled, int8 weights variant.
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  HybridFullyConnectedOpModel m(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
      {TensorType_FLOAT32},
      /*asymmetric_quantize_input*/ true);

  m.SetSignedWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     24, 25, 26,  //
                                     58, 59, 60,  //
                                 },
                                 /*max_abs_error=*/1.3f)));
}
982 
// A 4-D {4, 1, 5, 1} input (20 elements) is consumed as 2 batches of 10.
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
  // Note that it is not required that the first dimension be the number of
  // batches. All we care is that the input can be evenly distributed in
  // batches. In this case, we need the input to have multiples of '2'.
  FloatFullyConnectedOpModel m(GetRegistration(),
                               /*units=*/3, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // first batch
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // second batch
  });

  m.Invoke();

  // Output is flattened to {batches, units} (keep_num_dims is not requested).
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 24, 25, 26,  // first batch
                                 58, 59, 60,  // second batch
                             }));
}
1010 
// With keep_num_dims=true the output preserves the input's leading dims:
// {1, 2, 1, 10} input yields a {1, 2, 1, 3} output.
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput4DOutput) {
  // Note that it is not required that the first dimension be the number of
  // batches. All we care is that the input can be evenly distributed in
  // batches. In this case, we need the input to have multiples of '2'.
  FloatFullyConnectedOpModel m(GetRegistration(),
                               /*units=*/3, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {1, 2, 1, 10}},
                               /*output=*/{TensorType_FLOAT32},
                               /*keep_num_dims=*/true);
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // first batch
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // second batch
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 1, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 24, 25, 26,  // first batch
                                 58, 59, 60,  // second batch
                             }));
}
1040 
1041 #ifdef GTEST_HAS_DEATH_TEST
// Verifies that an input shape incompatible with keep_num_dims aborts during
// tensor allocation (death test).
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInputInvalidShape) {
  // Note that it is not required that the first dimension be the number of
  // batches. But it is required that the last dimension is the 'input_dim'.
  //
  // For this particular test, it is required for the output to be reformattable
  // into a shape of form {4, 1, 5, ?} but since the output size (the product of
  // output dimensions: units times batches) is 6, this is not possible.
  EXPECT_DEATH(FloatFullyConnectedOpModel m(
                   GetRegistration(), /*units=*/3, /*batches=*/2,
                   /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}},
                   /*output=*/{TensorType_FLOAT32},
                   /*keep_num_dims=*/true),
               "Cannot allocate tensors");
}
1056 #endif
1057 
// Fully quantized uint8 FC with a 4-D {4, 1, 5, 1} input consumed as
// 2 batches of 10. Checks both dequantized values and raw quantized output.
TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantizedUint8) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  //
                  58, 59, 60,  //
              })));
  // Raw quantized values for the same results under the declared output range.
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(151, 152, 153, 185, 186, 187));
}
1087 
// Exercises the requantization path where the effective output multiplier
// exceeds 1: the input range (-127..128) is twice the output range (-63.5..64).
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1Uint8) {
  // real_multiplier = 2.
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(175, 177, 179, 243, 245, 247));
}
1118 
// Instantiate each parameterized suite once per kernel tag in the relevant
// kernel map; the test parameter selects which registration to exercise.
INSTANTIATE_TEST_SUITE_P(
    FloatFullyConnectedOpTest, FloatFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));

// Quantized tests run against the kernel map that excludes the PIE variant.
INSTANTIATE_TEST_SUITE_P(
    QuantizedFullyConnectedOpTest, QuantizedFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
1126 
1127 // TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
1128 // to debug errors and doesn't necessarily test all the important details.
// Black-box regression test: runs the 8->16 FC layer over successive slices
// of the file-level `fully_connected_input` buffer and compares against the
// precomputed `fully_connected_golden_output` values.
TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/16, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 8}});
  m.SetWeights(
      {0.091327,  0.103366,  -0.316505, -0.083120, 0.149366,  -0.196636,
       -0.123672, 0.062800,  0.063031,  0.191670,  -0.062001, -0.061504,
       -0.275581, 0.059388,  -0.118497, -0.079224, 0.109758,  0.008307,
       -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650,
       0.266455,  0.051517,  -0.123448, 0.322464,  0.043282,  -0.173782,
       -0.190381, 0.002013,  0.096086,  0.131157,  0.031164,  0.100638,
       -0.312191, -0.080923, -0.101318, -0.116614, 0.142238,  0.086540,
       -0.139154, 0.174268,  -0.073161, 0.080072,  0.006874,  0.229382,
       -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824,
       -0.025021, 0.074460,  -0.252595, -0.161750, -0.136403, 0.008308,
       0.005710,  0.096600,  0.289839,  0.218816,  -0.304651, -0.070958,
       0.054598,  0.147113,  -0.139112, -0.072798, -0.163335, -0.167863,
       -0.128762, -0.035780, 0.117262,  0.017177,  0.263335,  -0.176612,
       0.262961,  -0.093654, -0.339283, 0.333071,  0.180827,  0.287583,
       0.066350,  -0.197947, -0.114449, -0.236035, 0.103532,  -0.034284,
       0.093299,  -0.145361, 0.054001,  0.250570,  0.157010,  -0.143480,
       -0.139061, -0.048873, 0.067557,  0.139038,  0.324106,  0.227041,
       0.037793,  -0.225747, -0.241619, 0.357835,  0.135762,  -0.306764,
       -0.125982, 0.091916,  0.266587,  0.030135,  0.265148,  0.141627,
       0.020120,  0.083815,  -0.124556, -0.100124, -0.048159, 0.181172,
       0.302309,  -0.041084, 0.146334,  -0.061511, -0.232605, 0.281324,
       0.145408,  -0.221897});
  m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860,
             0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804,
             0.048478, -0.032270, 0.175688, -0.085662});

  // Number of distinct single-batch input slices available in the buffer.
  const int input_sequence_size = sizeof(fully_connected_input) /
                                  sizeof(float) /
                                  (m.input_size() * m.num_batches());
  for (int i = 0; i < input_sequence_size; i++) {
    // TODO(ahentz): This is what the original test was doing: two equal
    // batches per invocation. We could instead use two different batches.
    float* batch_start = fully_connected_input + i * m.input_size();
    float* batch_end = batch_start + m.input_size();
    m.SetInput(0, batch_start, batch_end);
    m.SetInput(m.input_size(), batch_start, batch_end);

    m.Invoke();

    // Both batches are identical, so the golden slice is expected twice.
    float* golden_start = fully_connected_golden_output + i * m.num_units();
    float* golden_end = golden_start + m.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
1181 
1182 template <typename T>
1183 class SparseFullyConnectedOpModel : public SingleOpModel {
1184  public:
SparseFullyConnectedOpModel(TfLiteRegistration * registration,int units,int batches,const TensorData & input,const TensorData & weights,const std::vector<T> & weights_data,int num_threads=1,bool symmetric_quantize_weights=false)1185   SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
1186                               int batches, const TensorData& input,
1187                               const TensorData& weights,
1188                               const std::vector<T>& weights_data,
1189                               int num_threads = 1,
1190                               bool symmetric_quantize_weights = false)
1191       : batches_(batches), units_(units) {
1192     int total_input_size = 1;
1193     for (size_t i = 0; i < input.shape.size(); ++i) {
1194       total_input_size *= input.shape[i];
1195     }
1196     input_size_ = total_input_size / batches_;
1197 
1198     input_ = AddInput(input);
1199     weights_ =
1200         AddConstSparseInput(weights, weights_data, symmetric_quantize_weights);
1201 
1202     TensorData bias{input.type, {units_}};
1203     bias_ = AddInput(bias);
1204 
1205     output_ = AddOutput({input.type});
1206 
1207     SetBuiltinOp(
1208         BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
1209         CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
1210             .Union());
1211     resolver_ = absl::make_unique<SingleOpResolver>(
1212         BuiltinOperator_FULLY_CONNECTED, registration);
1213     BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)},
1214                      num_threads, /*allow_fp32_relax_to_fp16=*/false,
1215                      /*apply_delegate=*/false);
1216   }
SetBias(const std::vector<T> & data)1217   void SetBias(const std::vector<T>& data) { PopulateTensor(bias_, data); }
SetInput(const std::vector<T> & data)1218   void SetInput(const std::vector<T>& data) { PopulateTensor(input_, data); }
GetOutput()1219   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetOutputShape()1220   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1221 
input_size()1222   int input_size() { return input_size_; }
num_units()1223   int num_units() { return units_; }
num_batches()1224   int num_batches() { return batches_; }
1225 
1226  protected:
1227   int input_;
1228   int weights_;
1229   int bias_;
1230   int output_;
1231 
1232   int batches_;
1233   int units_;
1234   int input_size_;
1235 };
1236 
// Parameterized fixture for the sparse FC tests; sparse kernels live in the
// non-PIE kernel map.
class SparseFullyConnectedOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMapNoPie;
  }
};
1243 
// Basic sparse FC: dense 3x10 weights declared CSR-sparse along dim 1.
TEST_P(SparseFullyConnectedOpTest, SimpleTest) {
  const std::vector<float> dense_weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  TensorData weight_spec{};
  weight_spec.type = TensorType_FLOAT32;
  weight_spec.shape = {3, 10};
  weight_spec.traversal_order = {0, 1};
  weight_spec.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  SparseFullyConnectedOpModel<float> model(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}}, weight_spec, dense_weights);
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
1270 
// Minimal sparse FC: a single unit with weight row [2, 4] and bias 1.
TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
  const std::vector<float> dense_weights = {2, 4};  // u = 0
  TensorData weight_spec{};
  weight_spec.type = TensorType_FLOAT32;
  weight_spec.shape = {1, 2};
  weight_spec.traversal_order = {0, 1};
  weight_spec.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  SparseFullyConnectedOpModel<float> model(
      GetRegistration(), /*units=*/1, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 2}}, weight_spec, dense_weights);
  model.SetBias({1});

  model.SetInput({
      1, 2,  // b = 0: 2*1 + 4*2 + 1 = 11
      2, 1   // b = 1: 2*2 + 4*1 + 1 = 9
  });

  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(model.GetOutput(), ElementsAre(11, 9));
}
1295 
// Block-sparse FC with 1x4 blocks along the input dimension (block_map={1},
// block_size={4}); weights are dense 3x12.
TEST_P(SparseFullyConnectedOpTest, Simple1x4Test) {
  const std::vector<float> dense_weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight_spec{};
  weight_spec.type = TensorType_FLOAT32;
  weight_spec.shape = {3, 12};
  weight_spec.traversal_order = {0, 1, 2};
  weight_spec.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight_spec.block_map = {1};
  weight_spec.block_size = {4};
  SparseFullyConnectedOpModel<float> model(
      GetRegistration(),
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 12}}, weight_spec, dense_weights);
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 1
  });

  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
}
1325 
// Same problem as Simple1x4Test, repeated with 1..4 threads: results must be
// identical regardless of the thread count.
TEST_P(SparseFullyConnectedOpTest, Simple1x4TestMultiThreaded) {
  const std::vector<float> dense_weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight_spec{};
  weight_spec.type = TensorType_FLOAT32;
  weight_spec.shape = {3, 12};
  weight_spec.traversal_order = {0, 1, 2};
  weight_spec.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight_spec.block_map = {1};
  weight_spec.block_size = {4};
  for (int threads : {1, 2, 3, 4}) {
    SparseFullyConnectedOpModel<float> model(
        GetRegistration(),
        /*units=*/3, /*batches=*/2,
        /*input=*/{TensorType_FLOAT32, {2, 12}}, weight_spec, dense_weights,
        threads);
    model.SetBias({1, 2, 3});

    model.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 1
    });

    model.Invoke();

    EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
    EXPECT_THAT(model.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
  }
}
1358 
// Multithreaded 1x4 block-sparse test with 6 batches (three repetitions of
// the two batch patterns), run with 1..4 threads.
TEST_P(SparseFullyConnectedOpTest, Simple1x4TestMultiThreadedMoreBatches) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 12};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {4};
  for (int num_threads = 1; num_threads <= 4; num_threads++) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/3, /*batches=*/6,
        /*input=*/{TensorType_FLOAT32, {6, 12}}, weight, weight_data,
        num_threads);
    m.SetBias({1, 2, 3});

    m.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 1
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 2
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 3
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 4
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 5
    });

    m.Invoke();

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(6, 3));
    EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291,  // b = 0
                                           81, 82, 83,     // b = 1
                                           289, 290, 291,  // b = 2
                                           81, 82, 83,     // b = 3
                                           289, 290, 291,  // b = 4
                                           81, 82, 83      // b = 5
                                           ));
  }
}
1401 
// Hybrid sparse FC with 1x16 blocks: float input against symmetrically
// quantized sparse weights (4x48, mostly zero). Outputs pass through the
// op's RELU, so negatives clamp to 0; tolerance 1e-3 covers quantization.
TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16Test) {
  std::initializer_list<float> weight_data = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {4, 48};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  SparseFullyConnectedOpModel<float> m(
      GetRegistration(),
      /*units=*/4, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 48}}, weight, weight_data,
      /*num_threads=*/1, /*symmetric_quantize_weights=*/true);
  m.SetBias({1, 2, 3, 4});
  m.SetInput({
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 0
      2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
      -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
      0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
      -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
      1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 4));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0}, 1e-3)));
}
1457 
// Hybrid (float activations, symmetrically quantized weights) fully-connected
// op with 1x16 block-sparse weights, evaluated with 1-4 threads. Each thread
// count must reproduce the same per-batch results; batches 2/3 duplicate
// batches 0/1 so the expected output is the single-threaded result repeated.
TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16TestMultiThreaded) {
  // Dense representation of the 4x48 weight matrix; zeros form the sparse
  // structure that the 1x16 block format below compresses away.
  std::initializer_list<float> weight_data = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
  // Describe the weights as block-sparse: dim 0 dense, dim 1 stored in CSR
  // with 1x16 blocks (block_map={1}, block_size={16}); traversal_order covers
  // the two tensor dims plus the block dim.
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {4, 48};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  // Run the identical model at every thread count; results must not depend
  // on how the batch dimension is partitioned across threads.
  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/4, /*batches=*/4,
        /*input=*/{TensorType_FLOAT32, {4, 48}}, weight, weight_data,
        /*num_threads=*/num_threads, /*symmetric_quantize_weights=*/true);
    m.SetBias({1, 2, 3, 4});
    m.SetInput({
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 0
        2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
        -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
        0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
        -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
        1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 1
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 2
        2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
        -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
        0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
        -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
        1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 3
    });

    m.Invoke();

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 4));
    // Batches 2/3 repeat batches 0/1, so the expected values repeat too.
    // Loose 1e-3 tolerance accounts for weight-quantization error in the
    // hybrid path.
    EXPECT_THAT(m.GetOutput(),
                ElementsAreArray(ArrayFloatNear(
                    {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0, 0,
                     7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0},
                    1e-3)));
  }
}
1527 // TODO(b/148391360): Add tests for unsupported sparsity format.
1528 // TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
1529 
// Instantiate the parameterized suite once per registered kernel tag in
// kKernelMapNoPie, so every test above runs against each kernel variant.
// NOTE(review): the map name suggests the PIE implementation is deliberately
// excluded — confirm against the kKernelMapNoPie definition earlier in the
// file.
INSTANTIATE_TEST_SUITE_P(
    SparseFullyConnectedOpTest, SparseFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
1533 
1534 }  // namespace
1535 }  // namespace tflite
1536