1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"
16
17 #include <cstdint>
18 #include <initializer_list>
19 #include <memory>
20 #include <vector>
21
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/testing/util.h"
26
27 namespace {
28
29 using ::testing::ElementsAreArray;
30 using ::testing::Test;
31
32 class DimsAllocatingTest : public Test {
33 protected:
DimsAllocatingTest()34 DimsAllocatingTest() : allocated_dims_() {}
35
~DimsAllocatingTest()36 ~DimsAllocatingTest() override {
37 for (TfLiteIntArray* dim : allocated_dims_) {
38 TfLiteIntArrayFree(dim);
39 }
40 }
41
CreateDimArray(int size,std::initializer_list<int> dimensions)42 TfLiteIntArray* CreateDimArray(int size,
43 std::initializer_list<int> dimensions) {
44 TfLiteIntArray* dims = TfLiteIntArrayCreate(size);
45 allocated_dims_.push_back(dims);
46
47 int i = 0;
48 for (const int dimension : dimensions) {
49 dims->data[i++] = dimension;
50 }
51
52 return dims;
53 }
54
55 private:
56 std::vector<TfLiteIntArray*> allocated_dims_;
57 };
58
59 using tflite::delegate::nnapi::ExtractQuantLstmWeightsSubmatrix;
60
61 class ExtractQuantLstmWeightsSubmatrixTest : public DimsAllocatingTest {};
62
TEST_F(ExtractQuantLstmWeightsSubmatrixTest,TopLeftSubmatrixIsExtracted)63 TEST_F(ExtractQuantLstmWeightsSubmatrixTest, TopLeftSubmatrixIsExtracted) {
64 std::vector<uint8_t> weights = {1, 2, 3, 4, 5, //
65 11, 12, 13, 14, 15, //
66 101, 102, 103, 104, 105, //
67 111, 112, 113, 114, 115, //
68 201, 202, 203, 204, 205, //
69 211, 212, 213, 214, 215, //
70 221, 222, 223, 224, 225, //
71 231, 232, 233, 234, 235};
72 const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});
73
74 std::vector<uint8_t> submatrix;
75 const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 3});
76
77 ExtractQuantLstmWeightsSubmatrix(submatrix_dims, 0 /* offset_row */,
78 0 /* offset_column */, weight_dims,
79 weights.data(), &submatrix);
80
81 EXPECT_THAT(submatrix, ElementsAreArray({1, 2, 3, 11, 12, 13}));
82 }
83
TEST_F(ExtractQuantLstmWeightsSubmatrixTest,TopRightSubmatrixIsExtracted)84 TEST_F(ExtractQuantLstmWeightsSubmatrixTest, TopRightSubmatrixIsExtracted) {
85 std::vector<uint8_t> weights = {1, 2, 3, 4, 5, //
86 11, 12, 13, 14, 15, //
87 101, 102, 103, 104, 105, //
88 111, 112, 113, 114, 115, //
89 201, 202, 203, 204, 205, //
90 211, 212, 213, 214, 215, //
91 221, 222, 223, 224, 225, //
92 231, 232, 233, 234, 235};
93 const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});
94
95 std::vector<uint8_t> submatrix;
96 const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 2});
97
98 ExtractQuantLstmWeightsSubmatrix(submatrix_dims, 0 /* offset_row */,
99 3 /* offset_column */, weight_dims,
100 weights.data(), &submatrix);
101
102 EXPECT_THAT(submatrix, ElementsAreArray({4, 5, 14, 15}));
103 }
104
TEST_F(ExtractQuantLstmWeightsSubmatrixTest,RightCentralSubmatrixIsExtracted)105 TEST_F(ExtractQuantLstmWeightsSubmatrixTest, RightCentralSubmatrixIsExtracted) {
106 std::vector<uint8_t> weights = {1, 2, 3, 4, 5, //
107 11, 12, 13, 14, 15, //
108 101, 102, 103, 104, 105, //
109 111, 112, 113, 114, 115, //
110 201, 202, 203, 204, 205, //
111 211, 212, 213, 214, 215, //
112 221, 222, 223, 224, 225, //
113 231, 232, 233, 234, 235};
114 const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});
115
116 std::vector<uint8_t> submatrix;
117 const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 2});
118
119 ExtractQuantLstmWeightsSubmatrix(
120 submatrix_dims, 1 * submatrix_dims->data[0] /* offset_row */,
121 3 /* offset_column */, weight_dims, weights.data(), &submatrix);
122
123 EXPECT_THAT(submatrix, ElementsAreArray({104, 105, 114, 115}));
124 }
125
126 using tflite::delegate::nnapi::DecomposeQuantLstmWeightsTensor;
127
128 class QuantLstmWeightDecompTest : public DimsAllocatingTest {
129 protected:
QuantLstmWeightDecompTest()130 QuantLstmWeightDecompTest()
131 : weights_({1, 2, 3, 4, 5, //
132 11, 12, 13, 14, 15, //
133 101, 102, 103, 104, 105, //
134 111, 112, 113, 114, 115, //
135 201, 202, 203, 204, 205, //
136 211, 212, 213, 214, 215, //
137 221, 222, 223, 224, 225, //
138 231, 232, 233, 234, 235}),
139 // Creating the arrays empty, the size is set by the decomposition
140 // function
141 recurrent_to_input_(),
142 input_to_input_(),
143 recurrent_to_cell_(),
144 input_to_cell_(),
145 recurrent_to_forget_(),
146 input_to_forget_(),
147 recurrent_to_output_(),
148 input_to_output_() {
149 weight_dims_ = CreateDimArray(2, {8, 5});
150 }
151
152 const std::vector<uint8_t> weights_;
153 const TfLiteIntArray* weight_dims_;
154 std::vector<uint8_t> recurrent_to_input_;
155 std::vector<uint8_t> input_to_input_;
156 std::vector<uint8_t> recurrent_to_cell_;
157 std::vector<uint8_t> input_to_cell_;
158 std::vector<uint8_t> recurrent_to_forget_;
159 std::vector<uint8_t> input_to_forget_;
160 std::vector<uint8_t> recurrent_to_output_;
161 std::vector<uint8_t> input_to_output_;
162 };
163
TEST_F(QuantLstmWeightDecompTest,ExtractRecurrentToInput)164 TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToInput) {
165 DecomposeQuantLstmWeightsTensor(
166 weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
167 &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
168 &input_to_forget_, &recurrent_to_output_, &input_to_output_);
169
170 EXPECT_THAT(recurrent_to_input_, ElementsAreArray({1, 2, //
171 11, 12}));
172 }
173
TEST_F(QuantLstmWeightDecompTest,ExtractInputToInput)174 TEST_F(QuantLstmWeightDecompTest, ExtractInputToInput) {
175 DecomposeQuantLstmWeightsTensor(
176 weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
177 &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
178 &input_to_forget_, &recurrent_to_output_, &input_to_output_);
179
180 EXPECT_THAT(input_to_input_, ElementsAreArray({3, 4, 5, //
181 13, 14, 15}));
182 }
183
TEST_F(QuantLstmWeightDecompTest,ExtractRecurrentToCell)184 TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToCell) {
185 DecomposeQuantLstmWeightsTensor(
186 weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
187 &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
188 &input_to_forget_, &recurrent_to_output_, &input_to_output_);
189
190 EXPECT_THAT(recurrent_to_cell_, ElementsAreArray({101, 102, //
191 111, 112}));
192 }
193
TEST_F(QuantLstmWeightDecompTest,ExtractInputToCell)194 TEST_F(QuantLstmWeightDecompTest, ExtractInputToCell) {
195 DecomposeQuantLstmWeightsTensor(
196 weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
197 &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
198 &input_to_forget_, &recurrent_to_output_, &input_to_output_);
199
200 EXPECT_THAT(input_to_cell_, ElementsAreArray({103, 104, 105, //
201 113, 114, 115}));
202 }
203
TEST_F(QuantLstmWeightDecompTest,ExtractRecurrentToForget)204 TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToForget) {
205 DecomposeQuantLstmWeightsTensor(
206 weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
207 &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
208 &input_to_forget_, &recurrent_to_output_, &input_to_output_);
209
210 EXPECT_THAT(recurrent_to_forget_, ElementsAreArray({201, 202, //
211 211, 212}));
212 }
213
TEST_F(QuantLstmWeightDecompTest,ExtractInputToForget)214 TEST_F(QuantLstmWeightDecompTest, ExtractInputToForget) {
215 DecomposeQuantLstmWeightsTensor(
216 weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
217 &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
218 &input_to_forget_, &recurrent_to_output_, &input_to_output_);
219
220 EXPECT_THAT(input_to_forget_, ElementsAreArray({203, 204, 205, //
221 213, 214, 215}));
222 }
223
TEST_F(QuantLstmWeightDecompTest,ExtractRecurrentToOutput)224 TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToOutput) {
225 DecomposeQuantLstmWeightsTensor(
226 weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
227 &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
228 &input_to_forget_, &recurrent_to_output_, &input_to_output_);
229
230 EXPECT_THAT(recurrent_to_output_, ElementsAreArray({221, 222, //
231 231, 232}));
232 }
233
TEST_F(QuantLstmWeightDecompTest,ExtractInputToOutput)234 TEST_F(QuantLstmWeightDecompTest, ExtractInputToOutput) {
235 DecomposeQuantLstmWeightsTensor(
236 weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
237 &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
238 &input_to_forget_, &recurrent_to_output_, &input_to_output_);
239
240 EXPECT_THAT(input_to_output_, ElementsAreArray({223, 224, 225, //
241 233, 234, 235}));
242 }
243
244 using tflite::delegate::nnapi::DecomposeBiasTensor;
245
TEST(DecomposeBiasTensor,ExtractInputBias)246 TEST(DecomposeBiasTensor, ExtractInputBias) {
247 // clang-format off
248 std::vector<int32_t> biases
249 // inputGateBias
250 {-7876, 13488, -726, 32839,
251 // cellGateBias
252 39481, 48624, 48976, -21419,
253 // forgetGateBias
254 9206, -46884, -11693, -38724,
255 // outputGateBias
256 -58999, -17050, -41852, -40538};
257 // clang-format on
258
259 std::vector<int32_t> input_bias;
260 std::vector<int32_t> cell_bias;
261 std::vector<int32_t> forget_bias;
262 std::vector<int32_t> output_bias;
263 DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
264 &output_bias);
265
266 EXPECT_THAT(input_bias, ElementsAreArray({-7876, 13488, -726, 32839}));
267 }
268
TEST(DecomposeBiasTensor,ExtractCellBias)269 TEST(DecomposeBiasTensor, ExtractCellBias) {
270 // clang-format off
271 std::vector<int32_t> biases
272 // inputGateBias
273 {-7876, 13488, -726, 32839,
274 // cellGateBias
275 39481, 48624, 48976, -21419,
276 // forgetGateBias
277 9206, -46884, -11693, -38724,
278 // outputGateBias
279 -58999, -17050, -41852, -40538};
280 // clang-format on
281
282 std::vector<int32_t> input_bias;
283 std::vector<int32_t> cell_bias;
284 std::vector<int32_t> forget_bias;
285 std::vector<int32_t> output_bias;
286 DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
287 &output_bias);
288
289 EXPECT_THAT(cell_bias, ElementsAreArray({39481, 48624, 48976, -21419}));
290 }
291
TEST(DecomposeBiasTensor,ExtractForgetBias)292 TEST(DecomposeBiasTensor, ExtractForgetBias) {
293 // clang-format off
294 std::vector<int32_t> biases
295 // inputGateBias
296 {-7876, 13488, -726, 32839,
297 // cellGateBias
298 39481, 48624, 48976, -21419,
299 // forgetGateBias
300 9206, -46884, -11693, -38724,
301 // outputGateBias
302 -58999, -17050, -41852, -40538};
303 // clang-format on
304
305 std::vector<int32_t> input_bias;
306 std::vector<int32_t> cell_bias;
307 std::vector<int32_t> forget_bias;
308 std::vector<int32_t> output_bias;
309 DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
310 &output_bias);
311
312 EXPECT_THAT(forget_bias, ElementsAreArray({9206, -46884, -11693, -38724}));
313 }
314
TEST(DecomposeBiasTensor,ExtractOutputBias)315 TEST(DecomposeBiasTensor, ExtractOutputBias) {
316 // clang-format off
317 std::vector<int32_t> biases
318 // inputGateBias
319 {-7876, 13488, -726, 32839,
320 // cellGateBias
321 39481, 48624, 48976, -21419,
322 // forgetGateBias
323 9206, -46884, -11693, -38724,
324 // outputGateBias
325 -58999, -17050, -41852, -40538};
326 // clang-format on
327
328 std::vector<int32_t> input_bias;
329 std::vector<int32_t> cell_bias;
330 std::vector<int32_t> forget_bias;
331 std::vector<int32_t> output_bias;
332 DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
333 &output_bias);
334
335 EXPECT_THAT(output_bias, ElementsAreArray({-58999, -17050, -41852, -40538}));
336 }
337
338 } // namespace
339