• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"
16 
17 #include <cstdint>
18 #include <initializer_list>
19 #include <memory>
20 #include <vector>
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/testing/util.h"
26 
27 namespace {
28 
29 using ::testing::ElementsAreArray;
30 using ::testing::Test;
31 
32 class DimsAllocatingTest : public Test {
33  protected:
DimsAllocatingTest()34   DimsAllocatingTest() : allocated_dims_() {}
35 
~DimsAllocatingTest()36   ~DimsAllocatingTest() override {
37     for (TfLiteIntArray* dim : allocated_dims_) {
38       TfLiteIntArrayFree(dim);
39     }
40   }
41 
CreateDimArray(int size,std::initializer_list<int> dimensions)42   TfLiteIntArray* CreateDimArray(int size,
43                                  std::initializer_list<int> dimensions) {
44     TfLiteIntArray* dims = TfLiteIntArrayCreate(size);
45     allocated_dims_.push_back(dims);
46 
47     int i = 0;
48     for (const int dimension : dimensions) {
49       dims->data[i++] = dimension;
50     }
51 
52     return dims;
53   }
54 
55  private:
56   std::vector<TfLiteIntArray*> allocated_dims_;
57 };
58 
59 using tflite::delegate::nnapi::ExtractQuantLstmWeightsSubmatrix;
60 
61 class ExtractQuantLstmWeightsSubmatrixTest : public DimsAllocatingTest {};
62 
// Extracts the 2x3 block anchored at row 0, column 0 of an 8x5 matrix and
// checks that it holds the first three values of the first two rows.
TEST_F(ExtractQuantLstmWeightsSubmatrixTest, TopLeftSubmatrixIsExtracted) {
  // 8 rows x 5 columns, stored row by row.
  std::vector<uint8_t> source = {
      1,   2,   3,   4,   5,    //
      11,  12,  13,  14,  15,   //
      101, 102, 103, 104, 105,  //
      111, 112, 113, 114, 115,  //
      201, 202, 203, 204, 205,  //
      211, 212, 213, 214, 215,  //
      221, 222, 223, 224, 225,  //
      231, 232, 233, 234, 235,
  };
  const TfLiteIntArray* source_shape = CreateDimArray(2, {8, 5});

  std::vector<uint8_t> block;
  const TfLiteIntArray* block_shape = CreateDimArray(2, {2, 3});

  const int row_offset = 0;
  const int column_offset = 0;
  ExtractQuantLstmWeightsSubmatrix(block_shape, row_offset, column_offset,
                                   source_shape, source.data(), &block);

  EXPECT_THAT(block, ElementsAreArray({1, 2, 3, 11, 12, 13}));
}
83 
// Extracts the 2x2 block anchored at row 0, column 3 of an 8x5 matrix and
// checks that it holds the last two values of the first two rows.
TEST_F(ExtractQuantLstmWeightsSubmatrixTest, TopRightSubmatrixIsExtracted) {
  // 8 rows x 5 columns, stored row by row.
  std::vector<uint8_t> source = {
      1,   2,   3,   4,   5,    //
      11,  12,  13,  14,  15,   //
      101, 102, 103, 104, 105,  //
      111, 112, 113, 114, 115,  //
      201, 202, 203, 204, 205,  //
      211, 212, 213, 214, 215,  //
      221, 222, 223, 224, 225,  //
      231, 232, 233, 234, 235,
  };
  const TfLiteIntArray* source_shape = CreateDimArray(2, {8, 5});

  std::vector<uint8_t> block;
  const TfLiteIntArray* block_shape = CreateDimArray(2, {2, 2});

  const int row_offset = 0;
  const int column_offset = 3;
  ExtractQuantLstmWeightsSubmatrix(block_shape, row_offset, column_offset,
                                   source_shape, source.data(), &block);

  EXPECT_THAT(block, ElementsAreArray({4, 5, 14, 15}));
}
104 
// Extracts a 2x2 block from the interior rows of an 8x5 matrix: the row offset
// is one block height (2) and the column offset is 3, so the block covers the
// last two values of the third and fourth rows.
TEST_F(ExtractQuantLstmWeightsSubmatrixTest, RightCentralSubmatrixIsExtracted) {
  // 8 rows x 5 columns, stored row by row.
  std::vector<uint8_t> source = {
      1,   2,   3,   4,   5,    //
      11,  12,  13,  14,  15,   //
      101, 102, 103, 104, 105,  //
      111, 112, 113, 114, 115,  //
      201, 202, 203, 204, 205,  //
      211, 212, 213, 214, 215,  //
      221, 222, 223, 224, 225,  //
      231, 232, 233, 234, 235,
  };
  const TfLiteIntArray* source_shape = CreateDimArray(2, {8, 5});

  std::vector<uint8_t> block;
  const TfLiteIntArray* block_shape = CreateDimArray(2, {2, 2});

  // Skip exactly one block's worth of rows.
  const int row_offset = 1 * block_shape->data[0];
  const int column_offset = 3;
  ExtractQuantLstmWeightsSubmatrix(block_shape, row_offset, column_offset,
                                   source_shape, source.data(), &block);

  EXPECT_THAT(block, ElementsAreArray({104, 105, 114, 115}));
}
125 
126 using tflite::delegate::nnapi::DecomposeQuantLstmWeightsTensor;
127 
128 class QuantLstmWeightDecompTest : public DimsAllocatingTest {
129  protected:
QuantLstmWeightDecompTest()130   QuantLstmWeightDecompTest()
131       : weights_({1,   2,   3,   4,   5,    //
132                   11,  12,  13,  14,  15,   //
133                   101, 102, 103, 104, 105,  //
134                   111, 112, 113, 114, 115,  //
135                   201, 202, 203, 204, 205,  //
136                   211, 212, 213, 214, 215,  //
137                   221, 222, 223, 224, 225,  //
138                   231, 232, 233, 234, 235}),
139         // Creating the arrays empty, the size is set by the decomposition
140         // function
141         recurrent_to_input_(),
142         input_to_input_(),
143         recurrent_to_cell_(),
144         input_to_cell_(),
145         recurrent_to_forget_(),
146         input_to_forget_(),
147         recurrent_to_output_(),
148         input_to_output_() {
149     weight_dims_ = CreateDimArray(2, {8, 5});
150   }
151 
152   const std::vector<uint8_t> weights_;
153   const TfLiteIntArray* weight_dims_;
154   std::vector<uint8_t> recurrent_to_input_;
155   std::vector<uint8_t> input_to_input_;
156   std::vector<uint8_t> recurrent_to_cell_;
157   std::vector<uint8_t> input_to_cell_;
158   std::vector<uint8_t> recurrent_to_forget_;
159   std::vector<uint8_t> input_to_forget_;
160   std::vector<uint8_t> recurrent_to_output_;
161   std::vector<uint8_t> input_to_output_;
162 };
163 
// The recurrent-to-input matrix is expected to be the first two columns of
// rows 0-1 of the fixture's weight matrix.
TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToInput) {
  const std::vector<int> expected = {1, 2,  //
                                     11, 12};

  DecomposeQuantLstmWeightsTensor(weights_.data(), weight_dims_,
                                  &recurrent_to_input_, &input_to_input_,
                                  &recurrent_to_cell_, &input_to_cell_,
                                  &recurrent_to_forget_, &input_to_forget_,
                                  &recurrent_to_output_, &input_to_output_);

  EXPECT_THAT(recurrent_to_input_, ElementsAreArray(expected));
}
173 
// The input-to-input matrix is expected to be the last three columns of
// rows 0-1 of the fixture's weight matrix.
TEST_F(QuantLstmWeightDecompTest, ExtractInputToInput) {
  const std::vector<int> expected = {3, 4, 5,  //
                                     13, 14, 15};

  DecomposeQuantLstmWeightsTensor(weights_.data(), weight_dims_,
                                  &recurrent_to_input_, &input_to_input_,
                                  &recurrent_to_cell_, &input_to_cell_,
                                  &recurrent_to_forget_, &input_to_forget_,
                                  &recurrent_to_output_, &input_to_output_);

  EXPECT_THAT(input_to_input_, ElementsAreArray(expected));
}
183 
// The recurrent-to-cell matrix is expected to be the first two columns of
// rows 2-3 of the fixture's weight matrix.
TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToCell) {
  const std::vector<int> expected = {101, 102,  //
                                     111, 112};

  DecomposeQuantLstmWeightsTensor(weights_.data(), weight_dims_,
                                  &recurrent_to_input_, &input_to_input_,
                                  &recurrent_to_cell_, &input_to_cell_,
                                  &recurrent_to_forget_, &input_to_forget_,
                                  &recurrent_to_output_, &input_to_output_);

  EXPECT_THAT(recurrent_to_cell_, ElementsAreArray(expected));
}
193 
// The input-to-cell matrix is expected to be the last three columns of
// rows 2-3 of the fixture's weight matrix.
TEST_F(QuantLstmWeightDecompTest, ExtractInputToCell) {
  const std::vector<int> expected = {103, 104, 105,  //
                                     113, 114, 115};

  DecomposeQuantLstmWeightsTensor(weights_.data(), weight_dims_,
                                  &recurrent_to_input_, &input_to_input_,
                                  &recurrent_to_cell_, &input_to_cell_,
                                  &recurrent_to_forget_, &input_to_forget_,
                                  &recurrent_to_output_, &input_to_output_);

  EXPECT_THAT(input_to_cell_, ElementsAreArray(expected));
}
203 
// The recurrent-to-forget matrix is expected to be the first two columns of
// rows 4-5 of the fixture's weight matrix.
TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToForget) {
  const std::vector<int> expected = {201, 202,  //
                                     211, 212};

  DecomposeQuantLstmWeightsTensor(weights_.data(), weight_dims_,
                                  &recurrent_to_input_, &input_to_input_,
                                  &recurrent_to_cell_, &input_to_cell_,
                                  &recurrent_to_forget_, &input_to_forget_,
                                  &recurrent_to_output_, &input_to_output_);

  EXPECT_THAT(recurrent_to_forget_, ElementsAreArray(expected));
}
213 
// The input-to-forget matrix is expected to be the last three columns of
// rows 4-5 of the fixture's weight matrix.
TEST_F(QuantLstmWeightDecompTest, ExtractInputToForget) {
  const std::vector<int> expected = {203, 204, 205,  //
                                     213, 214, 215};

  DecomposeQuantLstmWeightsTensor(weights_.data(), weight_dims_,
                                  &recurrent_to_input_, &input_to_input_,
                                  &recurrent_to_cell_, &input_to_cell_,
                                  &recurrent_to_forget_, &input_to_forget_,
                                  &recurrent_to_output_, &input_to_output_);

  EXPECT_THAT(input_to_forget_, ElementsAreArray(expected));
}
223 
// The recurrent-to-output matrix is expected to be the first two columns of
// rows 6-7 of the fixture's weight matrix.
TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToOutput) {
  const std::vector<int> expected = {221, 222,  //
                                     231, 232};

  DecomposeQuantLstmWeightsTensor(weights_.data(), weight_dims_,
                                  &recurrent_to_input_, &input_to_input_,
                                  &recurrent_to_cell_, &input_to_cell_,
                                  &recurrent_to_forget_, &input_to_forget_,
                                  &recurrent_to_output_, &input_to_output_);

  EXPECT_THAT(recurrent_to_output_, ElementsAreArray(expected));
}
233 
// The input-to-output matrix is expected to be the last three columns of
// rows 6-7 of the fixture's weight matrix.
TEST_F(QuantLstmWeightDecompTest, ExtractInputToOutput) {
  const std::vector<int> expected = {223, 224, 225,  //
                                     233, 234, 235};

  DecomposeQuantLstmWeightsTensor(weights_.data(), weight_dims_,
                                  &recurrent_to_input_, &input_to_input_,
                                  &recurrent_to_cell_, &input_to_cell_,
                                  &recurrent_to_forget_, &input_to_forget_,
                                  &recurrent_to_output_, &input_to_output_);

  EXPECT_THAT(input_to_output_, ElementsAreArray(expected));
}
243 
244 using tflite::delegate::nnapi::DecomposeBiasTensor;
245 
// Splits a 16-entry bias tensor (four groups of four values) and checks that
// the input-gate output receives the first group.
TEST(DecomposeBiasTensor, ExtractInputBias) {
  // clang-format off
  std::vector<int32_t> all_biases = {
      -7876,  13488,  -726,   32839,   // inputGateBias
      39481,  48624,  48976,  -21419,  // cellGateBias
      9206,   -46884, -11693, -38724,  // forgetGateBias
      -58999, -17050, -41852, -40538,  // outputGateBias
  };
  // clang-format on

  std::vector<int32_t> input_gate;
  std::vector<int32_t> cell_gate;
  std::vector<int32_t> forget_gate;
  std::vector<int32_t> output_gate;

  DecomposeBiasTensor(all_biases.data(), 4, &input_gate, &cell_gate,
                      &forget_gate, &output_gate);

  EXPECT_THAT(input_gate, ElementsAreArray({-7876, 13488, -726, 32839}));
}
268 
// Splits a 16-entry bias tensor (four groups of four values) and checks that
// the cell-gate output receives the second group.
TEST(DecomposeBiasTensor, ExtractCellBias) {
  // clang-format off
  std::vector<int32_t> all_biases = {
      -7876,  13488,  -726,   32839,   // inputGateBias
      39481,  48624,  48976,  -21419,  // cellGateBias
      9206,   -46884, -11693, -38724,  // forgetGateBias
      -58999, -17050, -41852, -40538,  // outputGateBias
  };
  // clang-format on

  std::vector<int32_t> input_gate;
  std::vector<int32_t> cell_gate;
  std::vector<int32_t> forget_gate;
  std::vector<int32_t> output_gate;

  DecomposeBiasTensor(all_biases.data(), 4, &input_gate, &cell_gate,
                      &forget_gate, &output_gate);

  EXPECT_THAT(cell_gate, ElementsAreArray({39481, 48624, 48976, -21419}));
}
291 
// Splits a 16-entry bias tensor (four groups of four values) and checks that
// the forget-gate output receives the third group.
TEST(DecomposeBiasTensor, ExtractForgetBias) {
  // clang-format off
  std::vector<int32_t> all_biases = {
      -7876,  13488,  -726,   32839,   // inputGateBias
      39481,  48624,  48976,  -21419,  // cellGateBias
      9206,   -46884, -11693, -38724,  // forgetGateBias
      -58999, -17050, -41852, -40538,  // outputGateBias
  };
  // clang-format on

  std::vector<int32_t> input_gate;
  std::vector<int32_t> cell_gate;
  std::vector<int32_t> forget_gate;
  std::vector<int32_t> output_gate;

  DecomposeBiasTensor(all_biases.data(), 4, &input_gate, &cell_gate,
                      &forget_gate, &output_gate);

  EXPECT_THAT(forget_gate, ElementsAreArray({9206, -46884, -11693, -38724}));
}
314 
// Splits a 16-entry bias tensor (four groups of four values) and checks that
// the output-gate output receives the fourth group.
TEST(DecomposeBiasTensor, ExtractOutputBias) {
  // clang-format off
  std::vector<int32_t> all_biases = {
      -7876,  13488,  -726,   32839,   // inputGateBias
      39481,  48624,  48976,  -21419,  // cellGateBias
      9206,   -46884, -11693, -38724,  // forgetGateBias
      -58999, -17050, -41852, -40538,  // outputGateBias
  };
  // clang-format on

  std::vector<int32_t> input_gate;
  std::vector<int32_t> cell_gate;
  std::vector<int32_t> forget_gate;
  std::vector<int32_t> output_gate;

  DecomposeBiasTensor(all_biases.data(), 4, &input_gate, &cell_gate,
                      &forget_gate, &output_gate);

  EXPECT_THAT(output_gate, ElementsAreArray({-58999, -17050, -41852, -40538}));
}
337 
338 }  // namespace
339