• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/micro/all_ops_resolver.h"
19 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
20 #include "tensorflow/lite/micro/test_helpers.h"
21 #include "tensorflow/lite/micro/testing/micro_test.h"
22 
23 namespace tflite {
24 namespace testing {
25 namespace {
26 
#if !defined(XTENSA)
// Output quantization parameters. Softmax results lie in [0, 1.0], so the
// scale / zero-point pairs below are chosen to cover that interval for each
// output type.
constexpr float output_scale_int8 = 1.0f / 256.0f;
constexpr int output_zero_point_int8 = -128;

constexpr float output_scale_uint8 = 1.0f / 256.0f;
constexpr int output_zero_point_uint8 = 0;

constexpr float output_scale_int16 = 1.0f / 32768.0f;
constexpr int output_zero_point_int16 = 0;

// Allowed deviation for int16 outputs, measured in quantized units; the value
// was chosen empirically.
constexpr float tolerance_int16 = 7.0;

// 1-D case: a single row of five logits and its expected softmax.
constexpr int flat_size_1d = 5;
constexpr int shape_1d[] = {1, 5};
constexpr float input_data_1d[] = {1.0, 2.0, 3.0, 4.0, 5.0};
constexpr float golden_1d[] = {0.011656231, 0.031684921, 0.086128544,
                               0.234121657, 0.636408647};
#endif  // !defined(XTENSA)
// 2-D case: two rows of five logits. Row 1 is the negation of row 0, so its
// expected softmax is row 0's result in reverse order.
constexpr int flat_size_2d = 10;
constexpr int shape_2d[] = {2, 2, 5};
constexpr float input_data_2d[] = {
    1.0,  2.0,  3.0,  4.0,  5.0,   // row 0
    -1.0, -2.0, -3.0, -4.0, -5.0,  // row 1
};
constexpr float golden_2d[] = {
    0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,  // row 0
    0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231,  // row 1
};
56 
#if !defined(XTENSA)
// 3-dimensional test data. shape_3d lists the rank first, then the dims
// (c = 3, h = 4, w = 5). Each run of five consecutive values is one innermost
// row, and every row of golden_3d sums to 1.
const int flat_size_3d = 60;
const int shape_3d[] = {3, 3, 4, 5};
const float input_data_3d[] = {
    // c = 0
    // h = 0
    3.00, 6.00, -5.00, 4.00, -9.00,
    // h = 1
    -10.00, -10.00, -8.00, 2.00, 2.00,
    // h = 2
    8.00, -5.00, -8.00, 5.00, -6.00,
    // h = 3
    -8.00, 6.00, 1.00, -10.00, -8.00,

    // c = 1
    // h = 0
    7.00, 6.00, -10.00, -4.00, -5.00,
    // h = 1
    2.00, 7.00, 9.00, -9.00, 7.00,
    // h = 2
    -4.00, -2.00, 8.00, 2.00, 2.00,
    // h = 3
    3.00, 6.00, 6.00, 2.00, 4.00,

    // c = 2
    // h = 0
    9.00, 7.00, -7.00, 0.00, 4.00,
    // h = 1
    -3.00, 8.00, 8.00, -3.00, -4.00,
    // h = 2
    -9.00, -9.00, 4.00, -8.00, -1.00,
    // h = 3
    -10.00, -2.00, 6.00, -7.00, 0.00};
// Guard against a transcription slip that would make the table disagree with
// the element count the tests size their buffers with.
static_assert(static_cast<int>(sizeof(input_data_3d) /
                               sizeof(input_data_3d[0])) == flat_size_3d,
              "input_data_3d must hold flat_size_3d elements");

// Expected softmax of input_data_3d. Declared const like every other data
// table here: it is only ever handed to `const float*` parameters.
const float golden_3d[] = {
    // c = 0
    // h = 0
    0.042009463, 0.843782625, 0.000014093, 0.114193561, 0.000000258,
    // h = 1
    0.000003072, 0.000003072, 0.000022699, 0.499985578, 0.499985578,
    // h = 2
    0.952571219, 0.000002153, 0.000000107, 0.047425728, 0.000000792,
    // h = 3
    0.000000826, 0.993305397, 0.006692839, 0.000000112, 0.000000826,

    // c = 1
    // h = 0
    0.731046347, 0.268936922, 0.000000030, 0.000012210, 0.000004492,
    // h = 1
    0.000717124, 0.106430599, 0.786421666, 0.000000012, 0.106430599,
    // h = 2
    0.000006114, 0.000045174, 0.995015917, 0.002466398, 0.002466398,
    // h = 3
    0.022595176, 0.453836234, 0.453836234, 0.008312301, 0.061420055,

    // c = 2
    // h = 0
    0.875505904, 0.118486839, 0.000000099, 0.000108046, 0.005899112,
    // h = 1
    0.000008351, 0.499990113, 0.499990113, 0.000008351, 0.000003072,
    // h = 2
    0.000002245, 0.000002245, 0.993296627, 0.000006103, 0.006692780,
    // h = 3
    0.000000112, 0.000334520, 0.997191323, 0.000002254, 0.002471790};
static_assert(static_cast<int>(sizeof(golden_3d) / sizeof(golden_3d[0])) ==
                  flat_size_3d,
              "golden_3d must hold flat_size_3d elements");

// 4-dimensional test data: n = 2, c = 3, h = 4, w = 5. The n = 0 slice
// repeats the 3-D case above; n = 1 adds new rows.
const int flat_size_4d = 120;
const int shape_4d[] = {4, 2, 3, 4, 5};
const float input_data_4d[] = {
    // n = 0
    // c = 0
    // h = 0
    3.00, 6.00, -5.00, 4.00, -9.00,
    // h = 1
    -10.00, -10.00, -8.00, 2.00, 2.00,
    // h = 2
    8.00, -5.00, -8.00, 5.00, -6.00,
    // h = 3
    -8.00, 6.00, 1.00, -10.00, -8.00,

    // c = 1
    // h = 0
    7.00, 6.00, -10.00, -4.00, -5.00,
    // h = 1
    2.00, 7.00, 9.00, -9.00, 7.00,
    // h = 2
    -4.00, -2.00, 8.00, 2.00, 2.00,
    // h = 3
    3.00, 6.00, 6.00, 2.00, 4.00,

    // c = 2
    // h = 0
    9.00, 7.00, -7.00, 0.00, 4.00,
    // h = 1
    -3.00, 8.00, 8.00, -3.00, -4.00,
    // h = 2
    -9.00, -9.00, 4.00, -8.00, -1.00,
    // h = 3
    -10.00, -2.00, 6.00, -7.00, 0.00,

    // n = 1
    // c = 0
    // h = 0
    -9.00, -8.00, 6.00, -1.00, -5.00,
    // h = 1
    -10.00, -5.00, -10.00, 7.00, -2.00,
    // h = 2
    -5.00, -4.00, 1.00, 2.00, 2.00,
    // h = 3
    -2.00, -2.00, 1.00, 1.00, -4.00,

    // c = 1
    // h = 0
    -8.00, -3.00, 1.00, 1.00, -1.00,
    // h = 1
    -2.00, 6.00, -1.00, -5.00, 6.00,
    // h = 2
    -7.00, 8.00, 9.00, 0.00, 9.00,
    // h = 3
    -9.00, -5.00, -2.00, 0.00, 8.00,

    // c = 2
    // h = 0
    4.00, 2.00, -3.00, 5.00, 8.00,
    // h = 1
    -1.00, 1.00, -4.00, -9.00, 7.00,
    // h = 2
    3.00, -8.00, 0.00, 9.00, -4.00,
    // h = 3
    8.00, -1.00, 9.00, -9.00, 1.00};
static_assert(static_cast<int>(sizeof(input_data_4d) /
                               sizeof(input_data_4d[0])) == flat_size_4d,
              "input_data_4d must hold flat_size_4d elements");

// Expected softmax of input_data_4d, row by row.
const float golden_4d[] = {
    // n = 0
    // c = 0
    // h = 0
    0.042009463, 0.843782625, 0.000014093, 0.114193561, 0.000000258,
    // h = 1
    0.000003072, 0.000003072, 0.000022699, 0.499985578, 0.499985578,
    // h = 2
    0.952571219, 0.000002153, 0.000000107, 0.047425728, 0.000000792,
    // h = 3
    0.000000826, 0.993305397, 0.006692839, 0.000000112, 0.000000826,

    // c = 1
    // h = 0
    0.731046347, 0.268936922, 0.000000030, 0.000012210, 0.000004492,
    // h = 1
    0.000717124, 0.106430599, 0.786421666, 0.000000012, 0.106430599,
    // h = 2
    0.000006114, 0.000045174, 0.995015917, 0.002466398, 0.002466398,
    // h = 3
    0.022595176, 0.453836234, 0.453836234, 0.008312301, 0.061420055,

    // c = 2
    // h = 0
    0.875505904, 0.118486839, 0.000000099, 0.000108046, 0.005899112,
    // h = 1
    0.000008351, 0.499990113, 0.499990113, 0.000008351, 0.000003072,
    // h = 2
    0.000002245, 0.000002245, 0.993296627, 0.000006103, 0.006692780,
    // h = 3
    0.000000112, 0.000334520, 0.997191323, 0.000002254, 0.002471790,

    // n = 1
    // c = 0
    // h = 0
    0.000000306, 0.000000831, 0.999071142, 0.000911035, 0.000016686,
    // h = 1
    0.000000041, 0.000006143, 0.000000041, 0.999870380, 0.000123394,
    // h = 2
    0.000384554, 0.001045327, 0.155140254, 0.421714933, 0.421714933,
    // h = 3
    0.023637081, 0.023637081, 0.474763454, 0.474763454, 0.003198931,

    // c = 1
    // h = 0
    0.000057299, 0.008503973, 0.464301197, 0.464301197, 0.062836334,
    // h = 1
    0.000167625, 0.499684188, 0.000455653, 0.000008346, 0.499684188,
    // h = 2
    0.000000048, 0.155354299, 0.422296769, 0.000052116, 0.422296769,
    // h = 3
    0.000000041, 0.000002259, 0.000045383, 0.000335334, 0.999616982,

    // c = 2
    // h = 0
    0.017107856, 0.002315297, 0.000015600, 0.046503973, 0.934057274,
    // h = 1
    0.000334516, 0.002471755, 0.000016655, 0.000000112, 0.997176963,
    // h = 2
    0.002472313, 0.000000041, 0.000123089, 0.997402302, 0.000002254,
    // h = 3
    0.268866557, 0.000033181, 0.730855076, 0.000000011, 0.000245175};
static_assert(static_cast<int>(sizeof(golden_4d) / sizeof(golden_4d[0])) ==
                  flat_size_4d,
              "golden_4d must hold flat_size_4d elements");

#endif  // !defined(XTENSA)
253 template <typename T>
ValidateSoftmaxGoldens(TfLiteTensor * tensors,const int tensor_count,T * output_data,const T * expected_output,int output_dims_count,float tolerance)254 void ValidateSoftmaxGoldens(TfLiteTensor* tensors, const int tensor_count,
255                             T* output_data, const T* expected_output,
256                             int output_dims_count, float tolerance) {
257   TfLiteSoftmaxParams builtin_data = {1.0f};
258 
259   int inputs_array_data[] = {1, 0};
260   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
261   int outputs_array_data[] = {1, 1};
262   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
263 
264   const TfLiteRegistration registration = Register_SOFTMAX();
265   micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
266                              outputs_array, &builtin_data);
267 
268   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
269   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
270 
271   for (int i = 0; i < output_dims_count; ++i) {
272     TF_LITE_MICRO_EXPECT_NEAR(expected_output[i], output_data[i], tolerance);
273   }
274 }
275 
276 #if !defined(XTENSA)
TestSoftmaxFloat(const int * input_dims_data,const float * input_data,const int * output_dims_data,const float * expected_output_data,float * output_data)277 void TestSoftmaxFloat(const int* input_dims_data, const float* input_data,
278                       const int* output_dims_data,
279                       const float* expected_output_data, float* output_data) {
280   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
281   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
282   const int output_dims_count = ElementCount(*output_dims);
283 
284   constexpr int inputs_size = 1;
285   constexpr int outputs_size = 1;
286   constexpr int tensors_size = inputs_size + outputs_size;
287   TfLiteTensor tensors[tensors_size] = {
288       CreateTensor(input_data, input_dims),
289       CreateTensor(output_data, output_dims),
290   };
291 
292   ValidateSoftmaxGoldens(tensors, tensors_size, output_data,
293                          expected_output_data, output_dims_count, 1e-5);
294 }
295 #endif
296 
297 template <typename inputT, typename outputT>
TestSoftmaxQuantized(const int * input_dims_data,const float * input_data,inputT * input_quantized,float input_scale,int input_zero_point,const int * output_dims_data,const float * golden,outputT * golden_quantized,float output_scale,int output_zero_point,outputT * output_data,float tolerance=1.0)298 void TestSoftmaxQuantized(const int* input_dims_data, const float* input_data,
299                           inputT* input_quantized, float input_scale,
300                           int input_zero_point, const int* output_dims_data,
301                           const float* golden, outputT* golden_quantized,
302                           float output_scale, int output_zero_point,
303                           outputT* output_data, float tolerance = 1.0) {
304   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
305   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
306   const int output_dims_count = ElementCount(*output_dims);
307 
308   constexpr int inputs_size = 1;
309   constexpr int outputs_size = 1;
310   constexpr int tensors_size = inputs_size + outputs_size;
311   TfLiteTensor tensors[tensors_size] = {
312       CreateQuantizedTensor(input_data, input_quantized, input_dims,
313                             input_scale, input_zero_point),
314       CreateQuantizedTensor(output_data, output_dims, output_scale,
315                             output_zero_point),
316   };
317 
318   Quantize(golden, golden_quantized, output_dims_count, output_scale,
319            output_zero_point);
320 
321   ValidateSoftmaxGoldens(tensors, tensors_size, output_data, golden_quantized,
322                          output_dims_count, tolerance);
323 }
324 
325 }  // namespace
326 }  // namespace testing
327 }  // namespace tflite
328 
329 TF_LITE_MICRO_TESTS_BEGIN
330 
331 #if !defined(XTENSA)
TF_LITE_MICRO_TEST(Softmax1DFloatShouldMatchGolden) {
  namespace tt = tflite::testing;
  // Softmax preserves shape, so the same dims describe input and output.
  const int* dims = tt::shape_1d;
  float kernel_output[tt::flat_size_1d];
  tt::TestSoftmaxFloat(dims, tt::input_data_1d, dims, tt::golden_1d,
                       kernel_output);
}
338 
TF_LITE_MICRO_TEST(Softmax1DQuantizedUInt8ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // uint8 input quantization: scale 0.1, zero point 128.
  const float in_scale = 0.1f;
  const int in_zero_point = 128;

  uint8_t quantized_input[tt::flat_size_1d];
  uint8_t quantized_golden[tt::flat_size_1d];
  uint8_t kernel_output[tt::flat_size_1d];

  tt::TestSoftmaxQuantized(tt::shape_1d, tt::input_data_1d, quantized_input,
                           in_scale, in_zero_point, tt::shape_1d,
                           tt::golden_1d, quantized_golden,
                           tt::output_scale_uint8, tt::output_zero_point_uint8,
                           kernel_output);
}
353 
TF_LITE_MICRO_TEST(Softmax1DQuantizedInt8ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // int8 input quantization: scale 0.1, zero point 0.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;

  int8_t quantized_input[tt::flat_size_1d];
  int8_t quantized_golden[tt::flat_size_1d];
  int8_t kernel_output[tt::flat_size_1d];

  tt::TestSoftmaxQuantized(tt::shape_1d, tt::input_data_1d, quantized_input,
                           in_scale, in_zero_point, tt::shape_1d,
                           tt::golden_1d, quantized_golden,
                           tt::output_scale_int8, tt::output_zero_point_int8,
                           kernel_output);
}
368 
TF_LITE_MICRO_TEST(Softmax1DQuantizedInt16ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // int16 input quantization: scale 0.1, zero point 0.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;

  int16_t quantized_input[tt::flat_size_1d];
  int16_t quantized_golden[tt::flat_size_1d];
  int16_t kernel_output[tt::flat_size_1d];

  tt::TestSoftmaxQuantized(tt::shape_1d, tt::input_data_1d, quantized_input,
                           in_scale, in_zero_point, tt::shape_1d,
                           tt::golden_1d, quantized_golden,
                           tt::output_scale_int16, tt::output_zero_point_int16,
                           kernel_output);
}
383 
TF_LITE_MICRO_TEST(Softmax2DFloatShouldMatchGolden) {
  namespace tt = tflite::testing;
  // Softmax preserves shape, so the same dims describe input and output.
  const int* dims = tt::shape_2d;
  float kernel_output[tt::flat_size_2d];
  tt::TestSoftmaxFloat(dims, tt::input_data_2d, dims, tt::golden_2d,
                       kernel_output);
}
390 
TF_LITE_MICRO_TEST(Softmax2DQuantizedUInt8ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // uint8 input quantization: scale 0.1, zero point 128.
  const float in_scale = 0.1f;
  const int in_zero_point = 128;

  uint8_t quantized_input[tt::flat_size_2d];
  uint8_t quantized_golden[tt::flat_size_2d];
  uint8_t kernel_output[tt::flat_size_2d];

  tt::TestSoftmaxQuantized(tt::shape_2d, tt::input_data_2d, quantized_input,
                           in_scale, in_zero_point, tt::shape_2d,
                           tt::golden_2d, quantized_golden,
                           tt::output_scale_uint8, tt::output_zero_point_uint8,
                           kernel_output);
}
405 
TF_LITE_MICRO_TEST(Softmax2DQuantizedInt8ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // int8 input quantization: scale 0.1, zero point 0.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;

  int8_t quantized_input[tt::flat_size_2d];
  int8_t quantized_golden[tt::flat_size_2d];
  int8_t kernel_output[tt::flat_size_2d];

  tt::TestSoftmaxQuantized(tt::shape_2d, tt::input_data_2d, quantized_input,
                           in_scale, in_zero_point, tt::shape_2d,
                           tt::golden_2d, quantized_golden,
                           tt::output_scale_int8, tt::output_zero_point_int8,
                           kernel_output);
}
420 
TF_LITE_MICRO_TEST(Softmax2DQuantizedInt16ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // int16 input quantization: scale 0.1, zero point 0.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;

  int16_t quantized_input[tt::flat_size_2d];
  int16_t quantized_golden[tt::flat_size_2d];
  int16_t kernel_output[tt::flat_size_2d];

  tt::TestSoftmaxQuantized(tt::shape_2d, tt::input_data_2d, quantized_input,
                           in_scale, in_zero_point, tt::shape_2d,
                           tt::golden_2d, quantized_golden,
                           tt::output_scale_int16, tt::output_zero_point_int16,
                           kernel_output);
}
435 
TF_LITE_MICRO_TEST(Softmax3DFloatShouldMatchGolden) {
  namespace tt = tflite::testing;
  // Softmax preserves shape, so the same dims describe input and output.
  const int* dims = tt::shape_3d;
  float kernel_output[tt::flat_size_3d];
  tt::TestSoftmaxFloat(dims, tt::input_data_3d, dims, tt::golden_3d,
                       kernel_output);
}
442 
TF_LITE_MICRO_TEST(Softmax3DQuantizedUInt8ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // uint8 input quantization: scale 0.1, zero point 128.
  const float in_scale = 0.1f;
  const int in_zero_point = 128;

  uint8_t quantized_input[tt::flat_size_3d];
  uint8_t quantized_golden[tt::flat_size_3d];
  uint8_t kernel_output[tt::flat_size_3d];

  tt::TestSoftmaxQuantized(tt::shape_3d, tt::input_data_3d, quantized_input,
                           in_scale, in_zero_point, tt::shape_3d,
                           tt::golden_3d, quantized_golden,
                           tt::output_scale_uint8, tt::output_zero_point_uint8,
                           kernel_output);
}
457 
TF_LITE_MICRO_TEST(Softmax3DQuantizedInt8ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // int8 input quantization: scale 0.1, zero point 0.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;

  int8_t quantized_input[tt::flat_size_3d];
  int8_t quantized_golden[tt::flat_size_3d];
  int8_t kernel_output[tt::flat_size_3d];

  tt::TestSoftmaxQuantized(tt::shape_3d, tt::input_data_3d, quantized_input,
                           in_scale, in_zero_point, tt::shape_3d,
                           tt::golden_3d, quantized_golden,
                           tt::output_scale_int8, tt::output_zero_point_int8,
                           kernel_output);
}
472 
TF_LITE_MICRO_TEST(Softmax3DQuantizedInt16ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // int16 input quantization: scale 0.1, zero point 0.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;

  int16_t quantized_input[tt::flat_size_3d];
  int16_t quantized_golden[tt::flat_size_3d];
  int16_t kernel_output[tt::flat_size_3d];

  // Uses the empirical int16 tolerance instead of the default of 1.
  tt::TestSoftmaxQuantized(tt::shape_3d, tt::input_data_3d, quantized_input,
                           in_scale, in_zero_point, tt::shape_3d,
                           tt::golden_3d, quantized_golden,
                           tt::output_scale_int16, tt::output_zero_point_int16,
                           kernel_output, tt::tolerance_int16);
}
488 
TF_LITE_MICRO_TEST(Softmax4DFloatShouldMatchGolden) {
  namespace tt = tflite::testing;
  // Softmax preserves shape, so the same dims describe input and output.
  const int* dims = tt::shape_4d;
  float kernel_output[tt::flat_size_4d];
  tt::TestSoftmaxFloat(dims, tt::input_data_4d, dims, tt::golden_4d,
                       kernel_output);
}
495 
TF_LITE_MICRO_TEST(Softmax4DQuantizedUInt8ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // uint8 input quantization: scale 0.1, zero point 128.
  const float in_scale = 0.1f;
  const int in_zero_point = 128;

  uint8_t quantized_input[tt::flat_size_4d];
  uint8_t quantized_golden[tt::flat_size_4d];
  uint8_t kernel_output[tt::flat_size_4d];

  tt::TestSoftmaxQuantized(tt::shape_4d, tt::input_data_4d, quantized_input,
                           in_scale, in_zero_point, tt::shape_4d,
                           tt::golden_4d, quantized_golden,
                           tt::output_scale_uint8, tt::output_zero_point_uint8,
                           kernel_output);
}
510 
TF_LITE_MICRO_TEST(Softmax4DQuantizedInt8ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // int8 input quantization: scale 0.1, zero point 0.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;

  int8_t quantized_input[tt::flat_size_4d];
  int8_t quantized_golden[tt::flat_size_4d];
  int8_t kernel_output[tt::flat_size_4d];

  tt::TestSoftmaxQuantized(tt::shape_4d, tt::input_data_4d, quantized_input,
                           in_scale, in_zero_point, tt::shape_4d,
                           tt::golden_4d, quantized_golden,
                           tt::output_scale_int8, tt::output_zero_point_int8,
                           kernel_output);
}
525 
TF_LITE_MICRO_TEST(Softmax4DQuantizedInt16ShouldMatchGolden) {
  namespace tt = tflite::testing;
  // int16 input quantization: scale 0.1, zero point 0.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;

  int16_t quantized_input[tt::flat_size_4d];
  int16_t quantized_golden[tt::flat_size_4d];
  int16_t kernel_output[tt::flat_size_4d];

  // Uses the empirical int16 tolerance instead of the default of 1.
  tt::TestSoftmaxQuantized(tt::shape_4d, tt::input_data_4d, quantized_input,
                           in_scale, in_zero_point, tt::shape_4d,
                           tt::golden_4d, quantized_golden,
                           tt::output_scale_int16, tt::output_zero_point_int16,
                           kernel_output, tt::tolerance_int16);
}
541 #endif
542 
TF_LITE_MICRO_TEST(Softmax2DQuantizedInt8InputInt16OutputShouldMatchGolden) {
  namespace tt = tflite::testing;
  // Mixed precision: int8 input (scale 0.1, zero point 0) feeding an int16
  // output with scale 1/65536 and zero point -32768.
  const float in_scale = 0.1f;
  const int in_zero_point = 0;
  const float out_scale = 1.0f / 65536.0f;
  const int out_zero_point = -32768;

  int8_t quantized_input[tt::flat_size_2d];
  int16_t quantized_golden[tt::flat_size_2d];
  int16_t kernel_output[tt::flat_size_2d];

  tt::TestSoftmaxQuantized(tt::shape_2d, tt::input_data_2d, quantized_input,
                           in_scale, in_zero_point, tt::shape_2d,
                           tt::golden_2d, quantized_golden, out_scale,
                           out_zero_point, kernel_output);
}
558 
559 TF_LITE_MICRO_TESTS_END
560