1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite FULLY_CONNECTED op.
16
17 #include "tensorflow/lite/kernels/fully_connected.h"
18
19 #include <stddef.h>
20 #include <stdint.h>
21
22 #include <algorithm>
23 #include <initializer_list>
24 #include <limits>
25 #include <map>
26 #include <memory>
27 #include <random>
28 #include <string>
29 #include <vector>
30
31 #include <gmock/gmock.h>
32 #include <gtest/gtest.h>
33 #include "absl/memory/memory.h"
34 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
35 #include "tensorflow/lite/core/api/op_resolver.h"
36 #include "tensorflow/lite/interpreter.h"
37 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
38 #include "tensorflow/lite/kernels/test_util.h"
39 #include "tensorflow/lite/schema/schema_generated.h"
40 #include "tensorflow/lite/string_type.h"
41
42 namespace tflite {
43 namespace {
44
45 using ::testing::ElementsAre;
46 using ::testing::ElementsAreArray;
47
// 128 input activations in [0, 1] shared by fully-connected tests defined
// later in the file (the consuming tests are outside this chunk).
static float fully_connected_input[] = {
    0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653,
    0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390,
    0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314,
    0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550,
    0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112,
    0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999,
    0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142,
    0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494,
    0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081,
    0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552,
    0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320,
    0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458,
    0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115,
    0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771,
    0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582,
    0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962,
    0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202,
    0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691,
    0.921330, 0.139902};
68
// Golden outputs matching fully_connected_input: 16 stanzas of 16 values
// (one stanza per batch). All values are non-negative with many exact
// zeros, consistent with a ReLU activation.
static float fully_connected_golden_output[] = {
    0,        0.0732134,   0,        0,          0,         0.280859,
    0,        0.128927,    0,        0.0777251,  0,         0.270268,
    0.271435, 0.0173503,   0.335465, 0.235562,

    0,        0.0745866,   0,        0.051611,   0,         0.253876,
    0,        0.0814873,   0,        0.104104,   0,         0.248529,
    0.264194, 0,           0.302973, 0.166252,

    0,        0.0170409,   0,        0.0509851,  0,         0.212834,
    0,        0.0208326,   0,        0.129932,   0.203978,  0.103428,
    0.298051, 0,           0.332233, 0.00445903,

    0,        0.125246,    0,        0.0735336,  0,         0.0910256,
    0,        0,           0,        0.18933,    0.378111,  0.0712443,
    0.277298, 0.0123414,   0.267454, 0,

    0,        0.14687,     0,        0.155495,   0.0300215, 0.147256,
    0,        0,           0,        0.156412,   0.434914,  0.0461529,
    0.246508, 0,           0.363138, 0,

    0,        0,           0,        0.0212949,  0,         0.301708,
    0,        0.35497,     0,        0.406223,   0.0260211, 0.049195,
    0.197161, 0,           0.37316,  0,

    0,        0.221783,    0,        0,          0.0116515, 0.281945,
    0,        0,           0,        0,          0.285626,  0.181773,
    0.296401, 0.170452,    0.367135, 0.142597,

    0,        0,           0,        0,          0,         0.418886,
    0,        0.291063,    0,        0.227541,   0.0424759, 0.27589,
    0.398286, 0.177146,    0.40359,  0.121452,

    0,        0.0834884,   0,        0,          0,         0.287441,
    0,        0.0046838,   0,        0.0122087,  0,         0.217376,
    0.140183, 0.0948412,   0.436677, 0.0589876,

    0,        0.0289969,   0,        0.0921397,  0,         0.396802,
    0,        0.0126157,   0,        0.0968433,  0,         0.172271,
    0.173295, 0.0664741,   0.53645,  0.00915603,

    0,        0,           0,        0,          0,         0.147942,
    0,        0.263795,    0,        0.39782,    0,         0.382435,
    0.561072, 0.0579847,   0.145712, 0.13508,

    0,        0,           0,        0.16382,    0,         0.322294,
    0,        0.163798,    0,        0.405211,   0.367953,  0.076852,
    0.342473, 0.0834118,   0.377537, 0,

    0,        0.206,       0,        0,          0,         0.375769,
    0,        0,           0,        0,          0,         0.125165,
    0,        0.105591,    0.52055,  0.0536445,

    0,        0.259261,    0,        0,          0,         0.247707,
    0,        0,           0,        0,          0,         0.215862,
    0.149153, 0.224678,    0.359519, 0.129419,

    0,        0.17611,     0,        0.280895,   0,         0.576484,
    0,        0.000418848, 0,        0,          0,         0.151112,
    0.211902, 0,           0.566341, 0.106305,

    0,        0.0246284,   0,        0,          0,         0.196267,
    0,        0.0248624,   0,        0.265635,   0,         0.436199,
    0.408079, 0.134514,    0.328489, 0.411368};
133
// Test harness that builds a one-op model containing a single FULLY_CONNECTED
// operator. NOTE: the order of AddInput/AddOutput calls below determines the
// tensor indices and must not be changed.
class BaseFullyConnectedOpModel : public SingleOpModel {
 public:
  // TODO(ahentz): test different activation types too.
  //
  // Args:
  //   registration: the kernel implementation under test.
  //   units/batches: output depth and batch count; the per-batch input size
  //       is derived as (total input elements) / batches.
  //   input/output: tensor specs; the output defaults to plain float.
  //   keep_num_dims, activation_func, weights_format: forwarded into
  //       FullyConnectedOptions.
  //   bias_tensor_optional: when true, the bias input is a null tensor.
  //   add_bias_for_quantized: when false, quantized models are built with
  //       only two inputs (no bias tensor at all).
  BaseFullyConnectedOpModel(
      TfLiteRegistration* registration, int units, int batches,
      const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
      bool keep_num_dims = false, bool bias_tensor_optional = false,
      ActivationFunctionType activation_func = ActivationFunctionType_RELU,
      FullyConnectedOptionsWeightsFormat weights_format =
          FullyConnectedOptionsWeightsFormat_DEFAULT,
      bool add_bias_for_quantized = true)
      : batches_(batches), units_(units) {
    // Derive the per-batch input size from the full input shape.
    int total_input_size = 1;
    for (size_t i = 0; i < input.shape.size(); ++i) {
      total_input_size *= input.shape[i];
    }
    input_size_ = total_input_size / batches_;

    input_ = AddInput(input);
    if (input.type == TensorType_INT16) {
      // int16 activations are paired with int8 weights (fixed test range).
      weights_ = AddInput({TensorType_INT8, {units_, input_size_}, -63.5, 64});
    } else {
      // Otherwise the weights share the input's type and min/max range.
      weights_ =
          AddInput({input.type, {units_, input_size_}, input.min, input.max});
    }

    if (bias_tensor_optional) {
      bias_ = AddNullInput();
    } else if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {units_}});
    } else if (add_bias_for_quantized) {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(weights_);
      if (input.type == TensorType_INT16) {
        // int16 models carry an int64 bias; others use int32.
        TensorData bias{TensorType_INT64, {units_}, 0, 0, bias_scale};
        bias_ = AddInput(bias);
      } else {
        TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
        bias_ = AddInput(bias);
      }
    }

    output_ = AddOutput(output);
    if (weights_format != FullyConnectedOptionsWeightsFormat_DEFAULT) {
      // Non-default (shuffled) weights require a second output tensor
      // (presumably a shuffled-input workspace — confirm against the kernel).
      AddOutput({TensorType_UINT8, input.shape});
    }

    SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
                 BuiltinOptions_FullyConnectedOptions,
                 CreateFullyConnectedOptions(builder_, activation_func,
                                             weights_format, keep_num_dims)
                     .Union());
    resolver_ = absl::make_unique<SingleOpResolver>(
        BuiltinOperator_FULLY_CONNECTED, registration);
    std::vector<std::vector<int>> inputs = {GetShape(input_),
                                            GetShape(weights_)};
    if (add_bias_for_quantized) {
      // A null (optional) bias is represented by an empty shape.
      inputs.push_back((bias_ == kTfLiteOptionalTensor) ? std::vector<int>()
                                                        : GetShape(bias_));
    }
    BuildInterpreter(inputs);
  }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  // Tensor indices assigned by AddInput/AddOutput above.
  int input_;
  int weights_;
  int bias_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
};
213
214 class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
215 public:
216 using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
217
SetBias(const std::vector<float> & f)218 void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
219
SetWeights(const std::vector<float> & f)220 void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
221
SetInput(const std::vector<float> & data)222 void SetInput(const std::vector<float>& data) {
223 PopulateTensor(input_, data);
224 }
SetInput(int offset,float * begin,float * end)225 void SetInput(int offset, float* begin, float* end) {
226 PopulateTensor(input_, offset, begin, end);
227 }
228
GetOutput()229 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()230 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
231 };
232
// Quantized variant: setters quantize float data into the tensors and
// getters can read back either raw quantized or dequantized values.
class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
 public:
  using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;

  // Quantizes and loads the bias as int32 (uint8/int8 models).
  void SetBias(const std::vector<float>& data) {
    QuantizeAndPopulate<int32_t>(bias_, data);
  }
  // Quantizes and loads the bias as int64 (int16 models).
  void SetBias64(const std::vector<float>& data) {
    QuantizeAndPopulate<int64_t>(bias_, data);
  }
  // Quantizes and loads the weights as T.
  template <typename T>
  void SetWeights(const std::vector<float>& data) {
    QuantizeAndPopulate<T>(weights_, data);
  }

  // Rearranges `data` into the SHUFFLED4x16INT8 weights layout (4 output x
  // 16 input blocks stored contiguously), quantizes, and loads the result.
  template <typename T>
  void ShuffleAndSetWeights(const std::vector<float>& data, int input_depth,
                            int output_depth) {
    std::vector<float> shuffled_data(data.size());
    // The shuffled format only supports depths that tile exactly into
    // 4x16 blocks.
    CHECK_EQ(input_depth % 16, 0);
    CHECK_EQ(output_depth % 4, 0);
    float* shuffled_data_ptr = shuffled_data.data();
    // Emit each 4x16 block row-by-row: 4 output rows of 16 consecutive
    // input-depth entries.
    for (int block_o = 0; block_o < output_depth; block_o += 4) {
      for (int block_i = 0; block_i < input_depth; block_i += 16) {
        for (int o = 0; o < 4; o++) {
          for (int i = 0; i < 16; i++) {
            *shuffled_data_ptr++ =
                data[(block_o + o) * input_depth + block_i + i];
          }
        }
      }
    }
    TfLiteTensor* t = interpreter_->tensor(weights_);
    auto quantized_data =
        Quantize<T>(shuffled_data, t->params.scale, t->params.zero_point);
    // Flip the top bit of each quantized value; the shuffled-weights path
    // expects values stored with this uint8<->int8 offset applied
    // (NOTE(review): confirm against the kernel's shuffled format).
    for (T& q : quantized_data) {
      q ^= 0x80;
    }
    PopulateTensor(weights_, 0, quantized_data.data(),
                   quantized_data.data() + quantized_data.size());
  }

  // Quantizes and loads the input as T.
  template <typename T>
  void SetInput(const std::vector<float>& data) {
    QuantizeAndPopulate<T>(input_, data);
  }

  // Reads back the raw quantized output values.
  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

  // Reads back the output dequantized to float using the output tensor's
  // scale and zero point.
  template <typename T>
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
                         GetZeroPoint(output_));
  }
};
291
292 // In the hybrid model the weights are quantized (to uint8). But the bias,
293 // input (and output) are expected to be in float precision.
// Harness for the hybrid path: float input/bias/output with quantized
// weights. Always registers the PIE kernel (the only hybrid implementation).
class HybridFullyConnectedOpModel : public SingleOpModel {
 public:
  // `asymmetric_inputs` is forwarded as the last argument of
  // CreateFullyConnectedOptions (presumably asymmetric_quantize_inputs —
  // confirm against the FlatBuffer schema).
  HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
                              const TensorData& weights,
                              const TensorData& output = {TensorType_FLOAT32},
                              bool asymmetric_inputs = false,
                              int num_threads = 1)
      : batches_(batches), units_(units) {
    // Per-batch input size = total input elements / batches.
    int total_input_size = 1;
    for (size_t i = 0; i < input.shape.size(); ++i) {
      total_input_size *= input.shape[i];
    }
    input_size_ = total_input_size / batches_;

    input_ = AddInput(input);
    weights_ = AddInput(weights);

    // Bias is always a float tensor in the hybrid path.
    TensorData bias{TensorType_FLOAT32, {units_}};
    bias_ = AddInput(bias);

    output_ = AddOutput(output);

    auto options = CreateFullyConnectedOptions(
                       builder_, ActivationFunctionType_RELU,
                       tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
                       false, asymmetric_inputs)
                       .Union();
    SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
                 BuiltinOptions_FullyConnectedOptions, options);
    resolver_ = absl::make_unique<SingleOpResolver>(
        BuiltinOperator_FULLY_CONNECTED,
        ops::builtin::Register_FULLY_CONNECTED_PIE());
    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)},
                     num_threads, /*allow_fp32_relax_to_fp16=*/false,
                     /*apply_delegate=*/false);
  }
  // Loads the float bias tensor.
  void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
  // Symmetrically quantizes `data` into the (unsigned) weights tensor.
  void SetWeights(const std::vector<float>& data) {
    SymmetricQuantizeAndPopulate(weights_, data);
  }

  // Symmetrically quantizes `f` into signed (int8) weights.
  void SetSignedWeights(std::initializer_list<float> f) {
    SignedSymmetricQuantizeAndPopulate(weights_, f);
  }

  void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  // Tensor indices assigned by AddInput/AddOutput above.
  int input_;
  int weights_;
  int bias_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
};
357
// Kernel registrations exercised by the parameterized float tests.
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
    {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
    {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
    {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
});
363
// Parameterized float tests; the test parameter selects an entry of
// kKernelMap.
class FloatFullyConnectedOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMap;
  }
};
370
// Same as kKernelMap but without the PIE kernel; used by the quantized tests.
const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
    {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
    {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
});
375
// Parameterized quantized tests over the non-PIE kernels.
class QuantizedFullyConnectedOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMapNoPie;
  }
};
382
// Kernels exercised by the hybrid (float activations, quantized weights)
// tests.
const auto kKernelMapHybrid = new std::map<string, TfLiteRegistration*>({
    {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
    // Only Pie supports the hybrid path, so the optimized kernel should fall
    // back to the Pie path in such cases.
    {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
});
389
// Hybrid mode is used by the Pie quantized kernel.
// Parameterized hybrid tests; the test parameter selects an entry of
// kKernelMapHybrid.
class HybridFullyConnectedOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMapHybrid;
  }
};
397
398 // TODO(ahentz): add more small tests like this one, focused on making sure the
399 // calculations are correct.
TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
  FloatFullyConnectedOpModel model(GetRegistration(), /*units=*/3,
                                   /*batches=*/2,
                                   /*input=*/{TensorType_FLOAT32, {2, 10}});
  // Three identical units, so per batch the outputs differ only by bias.
  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 2
  });
  model.SetBias({1, 2, 3});
  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // batch 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // batch 1
  });

  model.Invoke();

  // Each output is dot(input, weights) + bias.
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
420
TEST_P(FloatFullyConnectedOpTest, SimpleTest2) {
  FloatFullyConnectedOpModel model(GetRegistration(), /*units=*/1,
                                   /*batches=*/2,
                                   /*input=*/{TensorType_FLOAT32, {2, 2}});
  model.SetWeights({2, 4});  // single unit
  model.SetBias({1});
  model.SetInput({
      1, 2,  // batch 0
      2, 1,  // batch 1
  });

  model.Invoke();

  // batch 0: 1*2 + 2*4 + 1 = 11;  batch 1: 2*2 + 1*4 + 1 = 9.
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(model.GetOutput(), ElementsAre(11, 9));
}
439
TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) {
  // The optimized kernel assumes that the bias is specified, so this test
  // pins the PIE registration directly instead of being parameterized.
  FloatFullyConnectedOpModel model(
      ops::builtin::Register_FULLY_CONNECTED_PIE(),
      /*units=*/1, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 2}},
      /*output=*/{TensorType_FLOAT32},
      /*keep_num_dims=*/false,
      /*bias_tensor_optional=*/true);
  model.SetWeights({2, 4});  // single unit

  model.SetInput({
      1, 2,  // batch 0
      2, 1,  // batch 1
  });

  model.Invoke();

  // Same data as SimpleTest2 but without the +1 bias: 10 and 8.
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(model.GetOutput(), ElementsAre(10, 8));
}
462
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // Note: input_product_scale < output_scale does not hold here.
  model.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 2
  });
  model.SetBias({1, 2, 3});
  model.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // batch 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // batch 1
  });

  model.Invoke();

  // Dequantized results should match the float reference.
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  //
                  58, 59, 60,  //
              })));
  // Raw quantized values for the same outputs.
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              ElementsAre(151, 152, 153, 185, 186, 187));
}
492
// Same setup as SimpleTestQuantizedUint8, but the model is built with no
// bias tensor at all (add_bias_for_quantized=false); every output is
// therefore smaller by the corresponding bias value.
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8NoBias) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128},
      /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/false,
      /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU,
      /*FullyConnectedOptionsWeightsFormat weights_format =*/
      FullyConnectedOptionsWeightsFormat_DEFAULT,
      /*add_bias_for_quantized =*/false);

  // input_product_scale < output_scale was not true.
  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  // Bias-free results: identical across units within a batch.
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  23, 23, 23,  //
                  57, 57, 57,  //
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(150, 150, 150, 184, 184, 184));
}
526
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128});

  // Note: input_product_scale < output_scale does not hold here.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 2
  });
  model.SetBias({1, 2, 3});
  model.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // batch 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // batch 1
  });

  model.Invoke();

  // Dequantized and raw int8 views of the same outputs.
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAre(23, 24, 25, 57, 58, 59));
}
552
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16) {
  // Symmetric int16 quantization (zero_point = 0); scale 128/65536 = 1/512
  // maps the int16 domain onto roughly [-64, 64).
  const float scale = 128.0 / 65536;
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
      /*output=*/{TensorType_INT16, {}, 0, 0, scale, 0});

  // input_product_scale < output_scale was not true.
  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  // int16 models carry an int64 bias.
  m.SetBias64({1, 2, 3});

  m.SetInput<int16_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  // Raw int16 values = float value / scale (e.g. 24 * 512 = 12288).
  EXPECT_THAT(m.GetOutput<int16_t>(),
              ElementsAre(12288, 12800, 13312, 29696, 30208, 30720));
}
580
// int8 counterpart of SimpleTestQuantizedUint8NoBias: model built with no
// bias tensor (add_bias_for_quantized=false).
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8NoBias) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128},
      /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/false,
      /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU,
      /*FullyConnectedOptionsWeightsFormat weights_format =*/
      FullyConnectedOptionsWeightsFormat_DEFAULT,
      /*add_bias_for_quantized =*/false);

  // input_product_scale < output_scale was not true.
  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });

  m.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  // Bias-free results: identical across units within a batch.
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(22, 22, 22, 56, 56, 56));
}
610
611 // Test the GEMV path.
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestSingleBatchQuantizedInt8) {
  // A single batch exercises the matrix-vector (GEMV) code path.
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/4, /*batches*/ 1,
      /*input=*/{TensorType_INT8, {1, 10}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128});

  // Note: input_product_scale < output_scale does not hold here.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 2
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 3
  });
  model.SetBias({1, 2, 3, 4});
  model.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10  // the single batch
  });

  model.Invoke();

  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({58, 59, 60, 61})));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAre(57, 58, 59, 60));
}
637
// Checks rescaling when the effective output multiplier exceeds 1 (the input
// range is twice the output range here, so real_multiplier = 2).
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedOutputMultiplierGreaterThan1Uint8) {
  // real_multiplier = 2.
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(175, 177, 179, 243, 245, 247));
}
668
// int8 counterpart of the test above: output multiplier > 1
// (real_multiplier = 2).
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedOutputMultiplierGreaterThan1Int8) {
  // real_multiplier = 2.
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT8, {2, 10}, -127, 128},
      /*output=*/{TensorType_INT8, {}, -63.5, 64});

  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(47, 49, 51, 115, 117, 119));
}
698
// Runs one randomized int16-output quantized test case and compares the op's
// output against a float reference computed inline. The weights/input are
// generated on the uint8 quantization grid so the only approximation error
// comes from the op's own arithmetic.
void SimpleTestQuantizedInt16OutputCase(
    TfLiteRegistration* registration, int input_depth, int output_depth,
    int batches, FullyConnectedOptionsWeightsFormat weights_format) {
  const uint8_t kWeightsZeroPoint = 128;
  const float kWeightsScale = 1.f / 128.f;
  const uint8_t kInputZeroPoint = 128;
  const float kInputScale = 1.f / 128.f;
  const float kInputMin = (0 - kInputZeroPoint) * kInputScale;
  const float kInputMax = (255 - kInputZeroPoint) * kInputScale;
  // Output ranges in [-8..8] encoded as int16
  const float kOutputScale = 8.f / 32768.f;
  const float kOutputMin = -32768 * kOutputScale;
  const float kOutputMax = 32767 * kOutputScale;

  QuantizedFullyConnectedOpModel m(
      registration, output_depth, batches,
      /*input=*/
      {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax},
      /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax},
      /*keep_num_dims=*/false,
      /*bias_tensor_optional=*/false,
      /*activation_func=*/ActivationFunctionType_NONE, weights_format);

  // Default-seeded engine: deterministic data across runs.
  std::mt19937 random_engine;
  // Some compilers don't support uint8_t for uniform_distribution.
  std::uniform_int_distribution<uint32_t> weights_dist(
      0, std::numeric_limits<uint8_t>::max());

  // Generate float weights that land exactly on the uint8 quantization grid.
  std::vector<float> weights_data(input_depth * output_depth);
  for (auto& w : weights_data) {
    uint8_t q = static_cast<uint8_t>(weights_dist(random_engine));
    w = (q - kWeightsZeroPoint) * kWeightsScale;
  }

  // Based on weights_format, enforce any shape requirement for that format/path
  // and set the (possibly shuffled) weights.
  switch (weights_format) {
    case FullyConnectedOptionsWeightsFormat_DEFAULT:
      m.SetWeights<uint8_t>(weights_data);
      break;
    case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
      // The shuffled path currently supports only a restrictive subset of
      // shapes, described by the following assertions:
      CHECK_EQ(input_depth % 16, 0);
      CHECK_EQ(output_depth % 4, 0);
      CHECK(batches == 1 || batches == 4);
      m.ShuffleAndSetWeights<uint8_t>(weights_data, input_depth, output_depth);
      break;
    default:
      LOG(FATAL) << "Unhandled weights format";
  }

  // Some compilers don't support uint8_t for uniform_distribution.
  std::uniform_int_distribution<uint32_t> input_dist(
      0, std::numeric_limits<uint8_t>::max());
  // Input data likewise lands exactly on the uint8 grid.
  std::vector<float> input_data(input_depth * batches);
  for (auto& i : input_data) {
    uint8_t q = static_cast<uint8_t>(input_dist(random_engine));
    i = (q - kInputZeroPoint) * kInputScale;
  }

  std::vector<float> bias_data(output_depth);
  // As the output ranges in [-8, 8], it's reasonable to have bias values
  // in [-1, 1], this won't result in too much saturation.
  std::uniform_real_distribution<float> bias_dist(-1.f, 1.f);
  for (auto& b : bias_data) {
    b = bias_dist(random_engine);
  }

  m.SetBias(bias_data);
  m.SetInput<uint8_t>(input_data);

  m.Invoke();

  // Float reference: accum = bias + dot(input row, weights row), clamped to
  // the representable int16 output range (no activation function).
  std::vector<float> expected_output_data(output_depth * batches);
  for (int b = 0; b < batches; b++) {
    for (int o = 0; o < output_depth; o++) {
      float accum = bias_data[o];
      for (int i = 0; i < input_depth; i++) {
        accum +=
            input_data[b * input_depth + i] * weights_data[o * input_depth + i];
      }
      accum = std::min(accum, kOutputMax);
      accum = std::max(accum, kOutputMin);
      expected_output_data[b * output_depth + o] = accum;
    }
  }

  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_output_data, 3e-4f)));
}
790
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedInt16OutputDefaultWeights) {
  // Sweep a grid of shapes with the default (unshuffled) weights format.
  const std::initializer_list<int> kSizes = {1, 3, 10, 100};
  for (int input_depth : kSizes) {
    for (int output_depth : kSizes) {
      for (int batch : kSizes) {
        SimpleTestQuantizedInt16OutputCase(
            GetRegistration(), input_depth, output_depth, batch,
            FullyConnectedOptionsWeightsFormat_DEFAULT);
      }
    }
  }
}
803
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) {
  // The shuffled weights block shape is 4x16: rows = output_depth must be a
  // multiple of 4 and cols = input_depth a multiple of 16, so sweep whole
  // block counts rather than raw depths.
  for (int in_blocks : {1, 3}) {
    for (int out_blocks : {1, 3}) {
      const int input_depth = 16 * in_blocks;
      const int output_depth = 4 * out_blocks;
      // The fast shuffled path is currently supporting only batch sizes of 1
      // and 4. The idea is that the whole point of that path is to go as fast
      // as possible for small batch size, which requires fully specializing
      // it for each batch size, and for larger batch sizes the generic
      // gemmlowp-based implementation is fast enough.
      for (int batch : {1, 4}) {
        SimpleTestQuantizedInt16OutputCase(
            GetRegistration(), input_depth, output_depth, batch,
            FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
      }
    }
  }
}
827
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  HybridFullyConnectedOpModel model(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/
      {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0});  // Hybrid

  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // unit 2
  });
  model.SetBias({1, 2, 3});
  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // batch 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // batch 1
  });

  model.Invoke();

  // Hybrid quantization of the weights is lossy, so use a loose tolerance.
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         24, 25, 26,  //
                                         58, 59, 60,  //
                                     },
                                     /*max_abs_error=*/1.3f)));
}
856
// Same as SimpleTestQuantizedUint8 above, but with int8-quantized weights.
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  HybridFullyConnectedOpModel m(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0});  // Hybrid

  m.SetSignedWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     24, 25, 26,  //
                                     58, 59, 60,  //
                                 },
                                 /*max_abs_error=*/1.3f)));
}
884
// Runs the hybrid int8 test with 1 through 4 threads; the result must be the
// same regardless of the thread count.
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8MultiThreaded) {
  for (int thread_count : {1, 2, 3, 4}) {
    HybridFullyConnectedOpModel m(
        /*units=*/3, /*batches=*/4,
        /*input=*/{TensorType_FLOAT32, {4, 10}},
        /*weights=*/
        {TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
        /*output=*/{TensorType_FLOAT32}, /*asymmetric_inputs=*/false,
        /*num_threads=*/thread_count);  // Hybrid

    m.SetSignedWeights({
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
    });
    m.SetBias({1, 2, 3});

    m.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
        1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 2
        1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 3
    });

    m.Invoke();

    // Batches 2/3 duplicate batches 0/1, so the output rows repeat.
    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 3));
    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                   {
                                       24, 25, 26,  //
                                       58, 59, 60,  //
                                       24, 25, 26,  //
                                       58, 59, 60,  //
                                   },
                                   /*max_abs_error=*/1.3f)));
  }
}
922
// Hybrid path with asymmetric (per-batch) input quantization enabled for
// uint8 weights. The tolerance (0.64) is tighter than the symmetric hybrid
// tests above.
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  HybridFullyConnectedOpModel m(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/
      {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32},
      /*asymmetric_quantize_input=*/true);  // Hybrid asymmetric

  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     24, 25, 26,  //
                                     58, 59, 60,  //
                                 },
                                 /*max_abs_error=*/0.64f)));
}
952
// Same as the asymmetric uint8 test above, but with int8-quantized weights.
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  HybridFullyConnectedOpModel m(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
      {TensorType_FLOAT32},
      /*asymmetric_quantize_input=*/true);

  m.SetSignedWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     24, 25, 26,  //
                                     58, 59, 60,  //
                                 },
                                 /*max_abs_error=*/1.3f)));
}
982
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
  // Note that it is not required that the first dimension be the number of
  // batches. All we care is that the input can be evenly distributed in
  // batches. In this case, we need the input to have multiples of '2'.
  // Here the 4D shape {4, 1, 5, 1} holds 20 elements, which is flattened
  // into 2 batches of 10.
  FloatFullyConnectedOpModel m(GetRegistration(),
                               /*units=*/3, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // first batch
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // second batch
  });

  m.Invoke();

  // Without keep_num_dims the output collapses to 2D: {batches, units}.
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 24, 25, 26,  // first batch
                                 58, 59, 60,  // second batch
                             }));
}
1010
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput4DOutput) {
  // Note that it is not required that the first dimension be the number of
  // batches. All we care is that the input can be evenly distributed in
  // batches. In this case, we need the input to have multiples of '2'.
  // With keep_num_dims=true the output keeps the input's rank, replacing the
  // last dimension with the number of units.
  FloatFullyConnectedOpModel m(GetRegistration(),
                               /*units=*/3, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {1, 2, 1, 10}},
                               /*output=*/{TensorType_FLOAT32},
                               /*keep_num_dims=*/true);
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // first batch
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // second batch
  });

  m.Invoke();

  // Output shape is {1, 2, 1, 3}: input shape with the depth (10) replaced
  // by units (3).
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 1, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 24, 25, 26,  // first batch
                                 58, 59, 60,  // second batch
                             }));
}
1040
#ifdef GTEST_HAS_DEATH_TEST
// Verifies that model construction dies when keep_num_dims is requested but
// the input shape cannot be reformatted to hold the output.
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInputInvalidShape) {
  // Note that it is not required that the first dimension be the number of
  // batches. But it is required that the last dimension is the 'input_dim'.
  //
  // For this particular test, it is required for the output to be reformattable
  // into a shape of form {4, 1, 5, ?} but since the output size (the product of
  // output dimensions: units times batches) is 6, this is not possible.
  EXPECT_DEATH(FloatFullyConnectedOpModel m(
                   GetRegistration(), /*units=*/3, /*batches=*/2,
                   /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}},
                   /*output=*/{TensorType_FLOAT32},
                   /*keep_num_dims=*/true),
               "Cannot allocate tensors");
}
#endif
1057
// Fully-quantized uint8 test with a 4D input shape ({4, 1, 5, 1} flattens to
// 2 batches of 10). Checks both the dequantized values and the raw uint8
// output.
TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantizedUint8) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  //
                  58, 59, 60,  //
              })));
  // Raw quantized values corresponding to the dequantized results above.
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(151, 152, 153, 185, 186, 187));
}
1087
// Same 4D-input quantized test, but with input/output scales chosen so the
// requantization multiplier is greater than 1.
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1Uint8) {
  // real_multiplier = 2.
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  // Raw quantized values corresponding to the dequantized results above.
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(175, 177, 179, 243, 245, 247));
}
1118
// Instantiate the parameterized float and quantized suites once per
// registered kernel variant in their respective kernel maps.
INSTANTIATE_TEST_SUITE_P(
    FloatFullyConnectedOpTest, FloatFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));

INSTANTIATE_TEST_SUITE_P(
    QuantizedFullyConnectedOpTest, QuantizedFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
1126
1127 // TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
1128 // to debug errors and doesn't necessarily test all the important details.
TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/16, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 8}});
  m.SetWeights(
      {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636,
       -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504,
       -0.275581, 0.059388, -0.118497, -0.079224, 0.109758, 0.008307,
       -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650,
       0.266455, 0.051517, -0.123448, 0.322464, 0.043282, -0.173782,
       -0.190381, 0.002013, 0.096086, 0.131157, 0.031164, 0.100638,
       -0.312191, -0.080923, -0.101318, -0.116614, 0.142238, 0.086540,
       -0.139154, 0.174268, -0.073161, 0.080072, 0.006874, 0.229382,
       -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824,
       -0.025021, 0.074460, -0.252595, -0.161750, -0.136403, 0.008308,
       0.005710, 0.096600, 0.289839, 0.218816, -0.304651, -0.070958,
       0.054598, 0.147113, -0.139112, -0.072798, -0.163335, -0.167863,
       -0.128762, -0.035780, 0.117262, 0.017177, 0.263335, -0.176612,
       0.262961, -0.093654, -0.339283, 0.333071, 0.180827, 0.287583,
       0.066350, -0.197947, -0.114449, -0.236035, 0.103532, -0.034284,
       0.093299, -0.145361, 0.054001, 0.250570, 0.157010, -0.143480,
       -0.139061, -0.048873, 0.067557, 0.139038, 0.324106, 0.227041,
       0.037793, -0.225747, -0.241619, 0.357835, 0.135762, -0.306764,
       -0.125982, 0.091916, 0.266587, 0.030135, 0.265148, 0.141627,
       0.020120, 0.083815, -0.124556, -0.100124, -0.048159, 0.181172,
       0.302309, -0.041084, 0.146334, -0.061511, -0.232605, 0.281324,
       0.145408, -0.221897});
  m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860,
             0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804,
             0.048478, -0.032270, 0.175688, -0.085662});

  // Number of full (input_size * num_batches)-sized steps available in the
  // file-static fully_connected_input array.
  const int input_sequence_size = sizeof(fully_connected_input) /
                                  sizeof(float) /
                                  (m.input_size() * m.num_batches());
  for (int i = 0; i < input_sequence_size; i++) {
    // TODO(ahentz): This is what the original test was doing: two equal
    // batches per invocation. We could instead use two different batches.
    float* batch_start = fully_connected_input + i * m.input_size();
    float* batch_end = batch_start + m.input_size();
    // Copy the same slice into both batch slots of the input tensor.
    m.SetInput(0, batch_start, batch_end);
    m.SetInput(m.input_size(), batch_start, batch_end);

    m.Invoke();

    // Both batches were fed the same data, so the golden slice is expected
    // twice in a row.
    float* golden_start = fully_connected_golden_output + i * m.num_units();
    float* golden_end = golden_start + m.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
1181
// Model wrapper for FULLY_CONNECTED with sparse (const) weights. The weight
// tensor is added as a constant sparse input; input and bias share type T.
// The op is always built with RELU activation.
template <typename T>
class SparseFullyConnectedOpModel : public SingleOpModel {
 public:
  // `registration` selects the kernel variant under test.
  // `symmetric_quantize_weights` enables hybrid execution with symmetrically
  // quantized sparse weights.
  SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
                              int batches, const TensorData& input,
                              const TensorData& weights,
                              const std::vector<T>& weights_data,
                              int num_threads = 1,
                              bool symmetric_quantize_weights = false)
      : batches_(batches), units_(units) {
    // Derive per-batch input size from the total element count.
    int total_input_size = 1;
    for (size_t i = 0; i < input.shape.size(); ++i) {
      total_input_size *= input.shape[i];
    }
    input_size_ = total_input_size / batches_;

    input_ = AddInput(input);
    weights_ =
        AddConstSparseInput(weights, weights_data, symmetric_quantize_weights);

    // Bias has one element per output unit and the same type as the input.
    TensorData bias{input.type, {units_}};
    bias_ = AddInput(bias);

    output_ = AddOutput({input.type});

    SetBuiltinOp(
        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
            .Union());
    resolver_ = absl::make_unique<SingleOpResolver>(
        BuiltinOperator_FULLY_CONNECTED, registration);
    // The delegate is not applied so the sparse reference/optimized kernel
    // itself is exercised.
    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)},
                     num_threads, /*allow_fp32_relax_to_fp16=*/false,
                     /*apply_delegate=*/false);
  }
  void SetBias(const std::vector<T>& data) { PopulateTensor(bias_, data); }
  void SetInput(const std::vector<T>& data) { PopulateTensor(input_, data); }
  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  // Tensor indices within the interpreter.
  int input_;
  int weights_;
  int bias_;
  int output_;

  int batches_;
  int units_;
  int input_size_;  // Elements per batch.
};
1236
// Parameterized fixture for the sparse tests; runs against the kernel map
// that excludes the PIE variant.
class SparseFullyConnectedOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMapNoPie;
  }
};
1243
// Basic sparse test: dense row dimension, CSR-sparse column dimension, no
// blocking.
TEST_P(SparseFullyConnectedOpTest, SimpleTest) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 10};
  weight.traversal_order = {0, 1};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  SparseFullyConnectedOpModel<float> m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}}, weight, weight_data);
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
1270
// Minimal sparse test: a single output unit and 2-element input depth.
TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
  std::initializer_list<float> weight_data = {
      2, 4  // u = 0
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {1, 2};
  weight.traversal_order = {0, 1};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  SparseFullyConnectedOpModel<float> m(
      GetRegistration(), /*units=*/1, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 2}}, weight, weight_data);
  m.SetBias({1});

  m.SetInput({
      1, 2,  // b = 0
      2, 1   // b = 1
  });

  m.Invoke();

  // b=0: 1*2 + 2*4 + 1 = 11;  b=1: 2*2 + 1*4 + 1 = 9.
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
}
1295
// Block-sparse test using 1x4 blocks along the column dimension
// (block_map = {1}, block_size = {4}).
TEST_P(SparseFullyConnectedOpTest, Simple1x4Test) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 12};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {4};
  SparseFullyConnectedOpModel<float> m(GetRegistration(),
                                       /*units=*/3, /*batches=*/2,
                                       /*input=*/{TensorType_FLOAT32, {2, 12}},
                                       weight, weight_data);
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10, 11, 12,   // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10, -11, 12,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
}
1325
// Same as Simple1x4Test, but executed with 1 through 4 threads; the result
// must not depend on the thread count.
TEST_P(SparseFullyConnectedOpTest, Simple1x4TestMultiThreaded) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  // 1x4 block-sparse weight description (blocked along the column dim).
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 12};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {4};
  for (int thread_count : {1, 2, 3, 4}) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/3, /*batches=*/2,
        /*input=*/{TensorType_FLOAT32, {2, 12}}, weight, weight_data,
        thread_count);
    m.SetBias({1, 2, 3});

    m.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8, -9, -10, 11, 12,   // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9, -10, -11, 12,  // b = 1
    });

    m.Invoke();

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
    EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
  }
}
1358
// Multi-threaded 1x4 block-sparse test with 6 batches (batches 2/4 and 3/5
// duplicate batches 0 and 1), so work is split across threads.
TEST_P(SparseFullyConnectedOpTest, Simple1x4TestMultiThreadedMoreBatches) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 12};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {4};
  for (int num_threads = 1; num_threads <= 4; num_threads++) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/3, /*batches=*/6,
        /*input=*/{TensorType_FLOAT32, {6, 12}}, weight, weight_data,
        num_threads);
    m.SetBias({1, 2, 3});

    m.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8, -9, -10, 11, 12,   // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9, -10, -11, 12,  // b = 1
        1, 2, 3, 4, 5, 6, 7, 8, -9, -10, 11, 12,   // b = 2
        1, 2, 3, 4, 5, 6, 7, -8, 9, -10, -11, 12,  // b = 3
        1, 2, 3, 4, 5, 6, 7, 8, -9, -10, 11, 12,   // b = 4
        1, 2, 3, 4, 5, 6, 7, -8, 9, -10, -11, 12,  // b = 5
    });

    m.Invoke();

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(6, 3));
    EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291,  // b = 0
                                           81, 82, 83,     // b = 1
                                           289, 290, 291,  // b = 2
                                           81, 82, 83,     // b = 3
                                           289, 290, 291,  // b = 4
                                           81, 82, 83      // b = 5
                                           ));
  }
}
1401
// Hybrid sparse test: float input with 1x16 block-sparse weights that are
// symmetrically quantized (symmetric_quantize_weights=true). The RELU
// activation built into the model clamps negative results to 0 (visible in
// the zeros of the expected output).
TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16Test) {
  std::initializer_list<float> weight_data = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {4, 48};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  SparseFullyConnectedOpModel<float> m(
      GetRegistration(),
      /*units=*/4, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 48}}, weight, weight_data,
      /*num_threads=*/1, /*symmetric_quantize_weights=*/true);
  m.SetBias({1, 2, 3, 4});
  m.SetInput({
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 0
      2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
      -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
      0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
      -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
      1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 4));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0}, 1e-3)));
}
1457
// Multi-threaded variant of SparseHybrid1x16Test: 4 batches (2/3 duplicate
// 0/1), run with 1 through 4 threads; results must match the single-threaded
// expectation for every thread count.
TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16TestMultiThreaded) {
  std::initializer_list<float> weight_data = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {4, 48};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/4, /*batches=*/4,
        /*input=*/{TensorType_FLOAT32, {4, 48}}, weight, weight_data,
        /*num_threads=*/num_threads, /*symmetric_quantize_weights=*/true);
    m.SetBias({1, 2, 3, 4});
    m.SetInput({
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 0
        2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
        -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
        0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
        -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
        1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 1
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 2
        2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
        -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
        0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
        -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
        1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 3
    });

    m.Invoke();

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 4));
    EXPECT_THAT(m.GetOutput(),
                ElementsAreArray(ArrayFloatNear(
                    {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0, 0,
                     7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0},
                    1e-3)));
  }
}
1527 // TODO(b/148391360): Add tests for unsupported sparsity format.
1528 // TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
1529
// Instantiate the sparse suite over each kernel variant in the non-PIE map.
INSTANTIATE_TEST_SUITE_P(
    SparseFullyConnectedOpTest, SparseFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
1533
1534 } // namespace
1535 } // namespace tflite
1536