• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"

namespace tflite {
namespace {

using ::testing::ElementsAreArray;

26 class L2NormOpModel : public SingleOpModel {
27  public:
L2NormOpModel(const std::initializer_list<int> input_shape,const TensorType tensor_type,const ActivationFunctionType activation_type)28   L2NormOpModel(const std::initializer_list<int> input_shape,
29                 const TensorType tensor_type,
30                 const ActivationFunctionType activation_type) {
31     TensorData data = TensorData{tensor_type};
32     if (tensor_type != TensorType_FLOAT32) {
33       data.min = -2.0;
34       data.max = 2.0;
35       data.scale = 2.0;
36       data.zero_point = 128;
37     }
38     input_ = AddInput(data);
39     if (tensor_type != TensorType_FLOAT32) {
40       data.min = -1.0;
41       data.max = 127.0 / 128.0;
42     }
43     output_ = AddOutput(data);
44     SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
45                  CreateL2NormOptions(builder_, activation_type).Union());
46     BuildInterpreter({input_shape});
47   }
48 
SetInput(std::initializer_list<float> data)49   void SetInput(std::initializer_list<float> data) {
50     PopulateTensor(input_, data);
51   }
52 
53   template <typename T>
GetOutput()54   std::vector<T> GetOutput() {
55     return ExtractVector<T>(output_);
56   }
57 
58   template <typename T>
GetDequantizedOutput()59   std::vector<float> GetDequantizedOutput() {
60     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
61                          GetZeroPoint(output_));
62   }
63 
input() const64   int input() const { return input_; }
65 
66  private:
67   int input_;
68   int output_;
69 };
70 
// Float path: the input's L2 norm is 2.0, so each output is input / 2.
// Use ArrayFloatNear instead of exact equality — comparing computed floats
// exactly is brittle and inconsistent with ZerosVectorFloatTest below.
TEST(L2NormOpTest, SimpleFloatTest) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
                  ActivationFunctionType_NONE);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(
                  ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})));
}
79 
// An all-zero vector has norm 0 and cannot be normalized; the kernel is
// expected to produce all zeros for it.
TEST(L2NormOpTest, ZerosVectorFloatTest) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
                  ActivationFunctionType_NONE);
  m.SetInput({0, 0, 0, 0, 0, 0});
  m.Invoke();
  const std::vector<float> expected(6, 0.0f);
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected)));
}
88 
// Same values as SimpleFloatTest but with a rank-2 input shape, checking the
// op handles ranks below 4. Use ArrayFloatNear rather than exact float
// equality, which is brittle for computed results.
TEST(L2NormOpTest, SimpleFloatWithRankLessThanFourTest) {
  L2NormOpModel m({1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(
                  ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})));
}
96 
// Three identical batches; the expected output repeats the same normalized
// vector per batch. Use ArrayFloatNear rather than exact float equality,
// which is brittle for computed results.
TEST(L2NormOpTest, MultipleBatchFloatTest) {
  L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32,
                  ActivationFunctionType_NONE);
  m.SetInput({
      -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 1
      -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 2
      -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 3
  });
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({
                  -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 1
                  -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 2
                  -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 3
              })));
}
113 
// Quantized (uint8) zero vector: the real value 0 maps to 128, the zero
// point configured in the model harness.
TEST(L2NormOpTest, ZerosVectorUint8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<uint8_t>(m.input(), {0, 0, 0, 0, 0, 0});
  m.Invoke();
  const std::vector<uint8_t> expected_quantized(6, 128);
  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(expected_quantized));
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(
                  ArrayFloatNear(std::vector<float>(6, 0.0f), 0.1)));
}
124 
// Quantized (uint8) normalization of a simple vector; checks both the raw
// quantized output and its dequantized real values (within tolerance 0.1).
TEST(L2NormOpTest, SimpleUint8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<uint8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  m.Invoke();
  const std::vector<uint8_t> expected_quantized = {58, 166, 173, 205, 83, 134};
  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(expected_quantized));
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(
                  ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}
136 
// Quantized (int8) normalization of a simple vector; checks both the raw
// quantized output and its dequantized real values (within tolerance 0.1).
TEST(L2NormOpTest, SimpleInt8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<int8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  m.Invoke();
  const std::vector<int8_t> expected_quantized = {-70, 38, 45, 77, -45, 6};
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(expected_quantized));

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(
                  ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}
149 
// Quantized (int8) zero vector: the output stays at the int8 zero value.
TEST(L2NormOpTest, ZerosVectorInt8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<int8_t>(m.input(), {0, 0, 0, 0, 0, 0});
  m.Invoke();
  const std::vector<int8_t> expected_quantized(6, 0);
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(expected_quantized));

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(
                  ArrayFloatNear(std::vector<float>(6, 0.0f), 0.1)));
}
160 
// Quantized (uint8) normalization across three identical batches: the
// expected per-batch outputs are identical, so build the expectations by
// repeating one batch's values.
TEST(L2NormOpTest, MultipleBatchUint8Test) {
  L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<uint8_t>(m.input(),
                                 {
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 1
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 2
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 3
                                 });
  m.Invoke();
  std::vector<uint8_t> expected_quantized;
  std::vector<float> expected_dequantized;
  for (int batch = 0; batch < 3; ++batch) {
    expected_quantized.insert(expected_quantized.end(),
                              {58, 166, 173, 205, 83, 134});
    expected_dequantized.insert(expected_dequantized.end(),
                                {-0.55, 0.3, 0.35, 0.6, -0.35, 0.05});
  }
  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(expected_quantized));
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_dequantized, 0.1)));
}
186 
// Quantized (int8) normalization across three identical batches: the
// expected per-batch outputs are identical, so build the expectations by
// repeating one batch's values.
TEST(L2NormOpTest, MultipleBatchInt8Test) {
  L2NormOpModel m({3, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);

  m.QuantizeAndPopulate<int8_t>(m.input(),
                                {
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 1
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 2
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 3
                                });
  m.Invoke();
  std::vector<int8_t> expected_quantized;
  std::vector<float> expected_dequantized;
  for (int batch = 0; batch < 3; ++batch) {
    expected_quantized.insert(expected_quantized.end(),
                              {-70, 38, 45, 77, -45, 6});
    expected_dequantized.insert(expected_dequantized.end(),
                                {-0.55, 0.3, 0.35, 0.6, -0.35, 0.05});
  }
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(expected_quantized));
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_dequantized, 0.1)));
}
211 
}  // namespace
}  // namespace tflite