• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/micro/all_ops_resolver.h"
19 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
20 #include "tensorflow/lite/micro/test_helpers.h"
21 #include "tensorflow/lite/micro/testing/micro_test.h"
22 
23 namespace tflite {
24 namespace testing {
25 namespace {
26 
// The Logistic kernel assumes an output in the range [0, 1.0], leading to these
// quantization parameters: with scale 1/255 and zero point -128, the int8 value
// -128 represents 0.0 and 127 represents 1.0.
const float quantized_output_scale = 1.0 / 255.0;
const int quantized_output_zero_point_int8 = -128;

// Shape arrays are laid out as {rank, dim0, dim1, ...}. The product of the
// dims must equal the matching flat_size_*, because the test helpers only
// validate ElementCount(output shape) values; a shape that is too small
// silently skips the trailing elements.
constexpr int flat_size_basic = 10;
constexpr int shape_basic[] = {2, 2, 5};
const float input_data_basic[] = {1, 2, 3, 4, 5, -1, -2, -3, -4, -5};
const float golden_basic[] = {0.73105858, 0.88079708, 0.95257413, 0.98201379,
                              0.99330715, 0.26894142, 0.11920292, 0.04742587,
                              0.01798621, 0.00669285};

// 2 x 5 so that all ten values are exercised, including the saturating
// +93 / -93 inputs (previously {2, 1, 5} covered only the first five).
constexpr int flat_size_wide_range = 10;
constexpr int shape_wide_range[] = {2, 2, 5};
const float input_data_wide_range[] = {
    1.0, 2.0, 3.0, 4.0, 93.0, -1.0, -2.0, -3.0, -4.0, -93.0,
};
const float golden_wide_range[] = {
    0.73105858, 0.88079708, 0.95257413, 0.98201379, 1.0,
    0.26894142, 0.11920292, 0.04742587, 0.01798621, 0.0,
};

// Compile-time guards so the shapes, sizes and data arrays cannot drift apart.
static_assert(shape_basic[0] == 2 &&
                  shape_basic[1] * shape_basic[2] == flat_size_basic,
              "shape_basic must describe flat_size_basic elements");
static_assert(static_cast<int>(sizeof(input_data_basic) /
                               sizeof(input_data_basic[0])) == flat_size_basic,
              "input_data_basic must hold flat_size_basic values");
static_assert(static_cast<int>(sizeof(golden_basic) /
                               sizeof(golden_basic[0])) == flat_size_basic,
              "golden_basic must hold flat_size_basic values");
static_assert(shape_wide_range[0] == 2 &&
                  shape_wide_range[1] * shape_wide_range[2] ==
                      flat_size_wide_range,
              "shape_wide_range must describe flat_size_wide_range elements");
static_assert(static_cast<int>(sizeof(input_data_wide_range) /
                               sizeof(input_data_wide_range[0])) ==
                  flat_size_wide_range,
              "input_data_wide_range must hold flat_size_wide_range values");
static_assert(static_cast<int>(sizeof(golden_wide_range) /
                               sizeof(golden_wide_range[0])) ==
                  flat_size_wide_range,
              "golden_wide_range must hold flat_size_wide_range values");
48 
49 template <typename T>
ValidateLogisticGoldens(TfLiteTensor * tensors,const int tensor_count,T * output_data,const T * golden,int output_dims_count,float tolerance)50 void ValidateLogisticGoldens(TfLiteTensor* tensors, const int tensor_count,
51                              T* output_data, const T* golden,
52                              int output_dims_count, float tolerance) {
53   int inputs_array_data[] = {1, 0};
54   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
55   int outputs_array_data[] = {1, 1};
56   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
57 
58   const TfLiteRegistration registration =
59       tflite::ops::micro::Register_LOGISTIC();
60   micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
61                              outputs_array, nullptr);
62 
63   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
64   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
65 
66   for (int i = 0; i < output_dims_count; ++i) {
67     TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], tolerance);
68   }
69 }
70 
TestLogisticFloat(const int * input_dims_data,const float * input_data,const float * golden,const int * output_dims_data,float * output_data)71 void TestLogisticFloat(const int* input_dims_data, const float* input_data,
72                        const float* golden, const int* output_dims_data,
73                        float* output_data) {
74   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
75   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
76   const int output_elements_count = ElementCount(*output_dims);
77 
78   constexpr int inputs_size = 1;
79   constexpr int outputs_size = 1;
80   constexpr int tensors_size = inputs_size + outputs_size;
81   TfLiteTensor tensors[tensors_size] = {
82       CreateTensor(input_data, input_dims),
83       CreateTensor(output_data, output_dims),
84   };
85 
86   ValidateLogisticGoldens(tensors, tensors_size, output_data, golden,
87                           output_elements_count, 1e-5);
88 }
89 
90 template <typename T>
TestLogisticQuantized(const int * input_dims_data,const float * input_data,T * input_quantized,const float input_scale,const int input_zero_point,const float * golden,T * golden_quantized,const int * output_dims_data,const float output_scale,const int output_zero_point,int8_t * output_data)91 void TestLogisticQuantized(const int* input_dims_data, const float* input_data,
92                            T* input_quantized, const float input_scale,
93                            const int input_zero_point, const float* golden,
94                            T* golden_quantized, const int* output_dims_data,
95                            const float output_scale,
96                            const int output_zero_point, int8_t* output_data) {
97   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
98   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
99   const int output_elements_count = ElementCount(*output_dims);
100 
101   constexpr int inputs_size = 1;
102   constexpr int outputs_size = 1;
103   constexpr int tensors_size = inputs_size + outputs_size;
104   TfLiteTensor tensors[tensors_size] = {
105       CreateQuantizedTensor(input_data, input_quantized, input_dims,
106                             input_scale, input_zero_point),
107       CreateQuantizedTensor(output_data, output_dims, output_scale,
108                             output_zero_point),
109   };
110 
111   tflite::Quantize(golden, golden_quantized, output_elements_count,
112                    output_scale, output_zero_point);
113   ValidateLogisticGoldens(tensors, tensors_size, output_data, golden_quantized,
114                           output_elements_count, 1.0);
115 }
116 
117 }  // namespace
118 }  // namespace testing
119 }  // namespace tflite
120 
121 TF_LITE_MICRO_TESTS_BEGIN
122 
TF_LITE_MICRO_TEST(LogisticFloatBasicShouldMatchGolden)123 TF_LITE_MICRO_TEST(LogisticFloatBasicShouldMatchGolden) {
124   float output_data[tflite::testing::flat_size_basic];
125   tflite::testing::TestLogisticFloat(
126       tflite::testing::shape_basic, tflite::testing::input_data_basic,
127       tflite::testing::golden_basic, tflite::testing::shape_basic, output_data);
128 }
129 
TF_LITE_MICRO_TEST(LogisticQuantizedInt8BasicShouldMatchGolden)130 TF_LITE_MICRO_TEST(LogisticQuantizedInt8BasicShouldMatchGolden) {
131   const float input_scale = 0.1;
132   const int input_zero_point = 0;
133   int8_t input_quantized[tflite::testing::flat_size_basic];
134   int8_t golden_quantized[tflite::testing::flat_size_basic];
135   int8_t output_data[tflite::testing::flat_size_basic];
136 
137   tflite::testing::TestLogisticQuantized(
138       tflite::testing::shape_basic, tflite::testing::input_data_basic,
139       input_quantized, input_scale, input_zero_point,
140       tflite::testing::golden_basic, golden_quantized,
141       tflite::testing::shape_basic, tflite::testing::quantized_output_scale,
142       tflite::testing::quantized_output_zero_point_int8, output_data);
143 }
144 
TF_LITE_MICRO_TEST(LogisticFloatWideRangeShouldMatchGolden)145 TF_LITE_MICRO_TEST(LogisticFloatWideRangeShouldMatchGolden) {
146   float output_data[tflite::testing::flat_size_wide_range];
147   tflite::testing::TestLogisticFloat(
148       tflite::testing::shape_wide_range, tflite::testing::input_data_wide_range,
149       tflite::testing::golden_wide_range, tflite::testing::shape_wide_range,
150       output_data);
151 }
152 
TF_LITE_MICRO_TEST(LogisticQuantizedInt8WideRangeShouldMatchGolden)153 TF_LITE_MICRO_TEST(LogisticQuantizedInt8WideRangeShouldMatchGolden) {
154   const float input_scale = 1.0;
155   const int input_zero_point = 0;
156   int8_t input_quantized[tflite::testing::flat_size_wide_range];
157   int8_t golden_quantized[tflite::testing::flat_size_wide_range];
158   int8_t output_data[tflite::testing::flat_size_wide_range];
159 
160   tflite::testing::TestLogisticQuantized(
161       tflite::testing::shape_wide_range, tflite::testing::input_data_wide_range,
162       input_quantized, input_scale, input_zero_point,
163       tflite::testing::golden_wide_range, golden_quantized,
164       tflite::testing::shape_wide_range,
165       tflite::testing::quantized_output_scale,
166       tflite::testing::quantized_output_zero_point_int8, output_data);
167 }
168 
169 TF_LITE_MICRO_TESTS_END
170