• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include <initializer_list>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace {

using ::testing::ElementsAreArray;

31 class L2NormOpModel : public SingleOpModel {
32  public:
L2NormOpModel(const std::initializer_list<int> input_shape,const TensorType tensor_type,const ActivationFunctionType activation_type)33   L2NormOpModel(const std::initializer_list<int> input_shape,
34                 const TensorType tensor_type,
35                 const ActivationFunctionType activation_type) {
36     TensorData data = TensorData{tensor_type};
37     if (tensor_type != TensorType_FLOAT32) {
38       data.min = -2.0;
39       data.max = 2.0;
40       data.scale = 2.0;
41       data.zero_point = 128;
42     }
43     input_ = AddInput(data);
44     if (tensor_type != TensorType_FLOAT32) {
45       data.min = -1.0;
46       data.max = 127.0 / 128.0;
47     }
48     output_ = AddOutput(data);
49     SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
50                  CreateL2NormOptions(builder_, activation_type).Union());
51     BuildInterpreter({input_shape});
52   }
53 
SetInput(std::initializer_list<float> data)54   void SetInput(std::initializer_list<float> data) {
55     PopulateTensor(input_, data);
56   }
57 
58   template <typename T>
GetOutput()59   std::vector<T> GetOutput() {
60     return ExtractVector<T>(output_);
61   }
62 
63   template <typename T>
GetDequantizedOutput()64   std::vector<float> GetDequantizedOutput() {
65     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
66                          GetZeroPoint(output_));
67   }
68 
input() const69   int input() const { return input_; }
70 
71  private:
72   int input_;
73   int output_;
74 };
75 
// Float path: output equals input divided by its L2 norm (norm here is 2.0).
TEST(L2NormOpTest, SimpleFloatTest) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
                  ActivationFunctionType_NONE);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}

// Float path, all-zero input: the kernel must not divide by a zero norm;
// expected output is all zeros.
TEST(L2NormOpTest, ZerosVectorFloatTest) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
                  ActivationFunctionType_NONE);
  m.SetInput({0, 0, 0, 0, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0})));
}

// Float path with a rank-2 input: the op must accept tensors of rank < 4.
TEST(L2NormOpTest, SimpleFloatWithRankLessThanFourTest) {
  L2NormOpModel m({1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}

// Float path with batch size 3: each batch is normalized independently.
TEST(L2NormOpTest, MultipleBatchFloatTest) {
  L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32,
                  ActivationFunctionType_NONE);
  m.SetInput({
      -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 1
      -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 2
      -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 3
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray({
                  -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 1
                  -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 2
                  -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 3
              }));
}

// Quantized uint8 path, all-zero input: raw output is the zero point (128),
// i.e. dequantized zeros.
TEST(L2NormOpTest, ZerosVectorUint8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<uint8_t>(m.input(), {0, 0, 0, 0, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAreArray({128, 128, 128, 128, 128, 128}));
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1)));
}

// Quantized uint8 path: checks exact raw quantized values and that the
// dequantized result is within 0.1 of the float reference.
TEST(L2NormOpTest, SimpleUint8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<uint8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAreArray({58, 166, 173, 205, 83, 134}));
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(
                  ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}

// Quantized int8 path: same reference values as SimpleUint8Test, but with a
// signed representation (zero point 0 in raw values).
TEST(L2NormOpTest, SimpleInt8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<int8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<int8_t>(),
              ElementsAreArray({-70, 38, 45, 77, -45, 6}));

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(
                  ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}

// Quantized int8 path, all-zero input: raw and dequantized outputs are zero.
TEST(L2NormOpTest, ZerosVectorInt8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<int8_t>(m.input(), {0, 0, 0, 0, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({0, 0, 0, 0, 0, 0}));

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1)));
}

// Quantized uint8 path with batch size 3: each batch normalized
// independently; per-batch results match SimpleUint8Test.
TEST(L2NormOpTest, MultipleBatchUint8Test) {
  L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<uint8_t>(m.input(),
                                 {
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 1
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 2
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 3
                                 });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAreArray({
                  58, 166, 173, 205, 83, 134,  // batch 1
                  58, 166, 173, 205, 83, 134,  // batch 2
                  58, 166, 173, 205, 83, 134,  // batch 3
              }));
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 1
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 2
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 3
                  },
                  0.1)));
}

// Quantized int8 path with batch size 3: per-batch results match
// SimpleInt8Test.
TEST(L2NormOpTest, MultipleBatchInt8Test) {
  L2NormOpModel m({3, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<int8_t>(m.input(),
                                {
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 1
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 2
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 3
                                });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({
                                         -70, 38, 45, 77, -45, 6,  // batch 1
                                         -70, 38, 45, 77, -45, 6,  // batch 2
                                         -70, 38, 45, 77, -45, 6,  // batch 3
                                     }));
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 1
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 2
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 3
                  },
                  0.1)));
}

}  // namespace
}  // namespace tflite