1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/lstm_eval.h"
16
17 #include <stdint.h>
18 #include <stdlib.h>
19
20 #include <algorithm>
21 #include <memory>
22 #include <vector>
23
24 #include <gtest/gtest.h>
25 #include "tensorflow/lite/c/builtin_op_data.h"
26 #include "tensorflow/lite/c/common.h"
27 #include "tensorflow/lite/kernels/cpu_backend_context.h"
28
29 namespace tflite {
30 namespace {
31
32 // Validate result.
// Returns true iff the first `size` elements of `result` and
// `expected_result` compare equal element-wise.
template <typename T>
bool ArrayEq(const T* result, const T* expected_result, int size) {
  return std::equal(result, result + size, expected_result);
}
42
// Returns true iff every one of the first `size` corresponding element pairs
// of `result` and `expected_result` differs by no more than `threshold`.
template <typename T>
bool ArrayFloatNear(const T* result, const T* expected_result, int size,
                    double threshold) {
  return std::equal(result, result + size, expected_result,
                    [threshold](T lhs, T rhs) {
                      return !(std::abs(lhs - rhs) > threshold);
                    });
}
53
54 // Base class that holds input parameters for quantized and hybrid lstm.
55 class BaseLstmParam {
56 public:
Geti2i()57 TfLiteTensor* Geti2i() {
58 PackWeightToTensor(&i2i_tensor_, i2i_, i2i_size_);
59 i2i_tensor_.data.int8 = i2i_.data();
60 return &i2i_tensor_;
61 }
Geti2f()62 TfLiteTensor* Geti2f() {
63 PackWeightToTensor(&i2f_tensor_, i2f_, i2f_size_);
64 i2f_tensor_.data.int8 = i2f_.data();
65 return &i2f_tensor_;
66 }
Geti2c()67 TfLiteTensor* Geti2c() {
68 PackWeightToTensor(&i2c_tensor_, i2c_, i2c_size_);
69 i2c_tensor_.data.int8 = i2c_.data();
70 return &i2c_tensor_;
71 }
Geti2o()72 TfLiteTensor* Geti2o() {
73 PackWeightToTensor(&i2o_tensor_, i2o_, i2o_size_);
74 i2o_tensor_.data.int8 = i2o_.data();
75 return &i2o_tensor_;
76 }
Getr2i()77 TfLiteTensor* Getr2i() {
78 PackWeightToTensor(&r2i_tensor_, r2i_, r2i_size_);
79 r2i_tensor_.data.int8 = r2i_.data();
80 return &r2i_tensor_;
81 }
Getr2f()82 TfLiteTensor* Getr2f() {
83 PackWeightToTensor(&r2f_tensor_, r2f_, r2f_size_);
84 r2f_tensor_.data.int8 = r2f_.data();
85 return &r2f_tensor_;
86 }
Getr2c()87 TfLiteTensor* Getr2c() {
88 PackWeightToTensor(&r2c_tensor_, r2c_, r2c_size_);
89 r2c_tensor_.data.int8 = r2c_.data();
90 return &r2c_tensor_;
91 }
Getr2o()92 TfLiteTensor* Getr2o() {
93 PackWeightToTensor(&r2o_tensor_, r2o_, r2o_size_);
94 r2o_tensor_.data.int8 = r2o_.data();
95 return &r2o_tensor_;
96 }
GetProjection()97 TfLiteTensor* GetProjection() {
98 PackWeightToTensor(&projection_tensor_, projection_, projection_size_);
99 projection_tensor_.data.int8 = projection_.data();
100 return &projection_tensor_;
101 }
~BaseLstmParam()102 ~BaseLstmParam() {
103 TfLiteIntArrayFree(input_tensor_.dims);
104 TfLiteIntArrayFree(i2i_tensor_.dims);
105 TfLiteIntArrayFree(i2f_tensor_.dims);
106 TfLiteIntArrayFree(i2c_tensor_.dims);
107 TfLiteIntArrayFree(i2o_tensor_.dims);
108 TfLiteIntArrayFree(r2i_tensor_.dims);
109 TfLiteIntArrayFree(r2f_tensor_.dims);
110 TfLiteIntArrayFree(r2c_tensor_.dims);
111 TfLiteIntArrayFree(r2o_tensor_.dims);
112 TfLiteIntArrayFree(layer_norm_input_tensor_.dims);
113 TfLiteIntArrayFree(layer_norm_forget_tensor_.dims);
114 TfLiteIntArrayFree(layer_norm_cell_tensor_.dims);
115 TfLiteIntArrayFree(layer_norm_output_tensor_.dims);
116 TfLiteIntArrayFree(input_gate_bias_tensor_.dims);
117 TfLiteIntArrayFree(forget_gate_bias_tensor_.dims);
118 TfLiteIntArrayFree(cell_gate_bias_tensor_.dims);
119 TfLiteIntArrayFree(output_gate_bias_tensor_.dims);
120 TfLiteIntArrayFree(projection_tensor_.dims);
121 TfLiteIntArrayFree(projection_bias_tensor_.dims);
122 TfLiteIntArrayFree(activation_tensor_.dims);
123 TfLiteIntArrayFree(cell_tensor_.dims);
124 TfLiteIntArrayFree(output_tensor_.dims);
125 }
126
127 protected:
128 template <typename T>
PackWeightToTensor(TfLiteTensor * tensor,std::vector<T> & data,std::vector<int32_t> dims)129 void PackWeightToTensor(TfLiteTensor* tensor, std::vector<T>& data,
130 std::vector<int32_t> dims) {
131 if (data.empty()) {
132 int total = 1;
133 for (int i = 0; i < dims.size(); ++i) {
134 total *= dims[i];
135 }
136 for (int i = 0; i < total; ++i) {
137 data.push_back(0);
138 }
139 }
140 tensor->dims = TfLiteIntArrayCreate(dims.size());
141 for (int i = 0; i < dims.size(); ++i) {
142 tensor->dims->data[i] = dims[i];
143 }
144 }
145 // Dimensions. Need proper size to trigger neon code.
146 const int n_batch_ = 2;
147 const int n_input_ = 18;
148 const int n_cell_ = 10;
149 const int n_output_ = 6;
150
151 std::vector<int32_t> input_size_ = {n_batch_, n_input_};
152 TfLiteTensor input_tensor_;
153
154 // input_to_input_weights.
155 std::vector<int8_t> i2i_ = {
156 18, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
157 1, 2, 3, 4, 5, 6, 5, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 0, //
158 8, 2, 3, 4, 3, 6, 1, -2, 3, 4, 5, 6, 1, 2, 3, -4, 5, 6, //
159 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, -5, 6, 1, 7, 3, 4, -5, 6, //
160 8, 2, 3, 4, 5, 6, 3, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
161 1, -2, 2, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 8, 5, -6, //
162 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
163 1, 2, 3, 4, 3, 6, 1, 2, 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
164 8, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 1, 2, 3, 14, 5, 6, //
165 1, 2, 3, -4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
166 };
167 std::vector<int32_t> i2i_size_ = {n_cell_, n_input_};
168 TfLiteTensor i2i_tensor_;
169
170 // input_to_forget_weights.
171 std::vector<int8_t> i2f_ = {
172 1, 2, 3, 4, 5, 6, 5, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 0, //
173 8, 2, 3, 4, 3, 6, 1, -2, 3, 4, 5, 6, 1, 2, 3, -4, 5, 6, //
174 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, -5, 6, 1, 7, 3, 4, -5, 6, //
175 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
176 1, 2, 3, 4, 3, 6, 1, 2, 6, 4, 5, 6, 11, 2, 3, 4, -5, 6, //
177 8, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, -6, 1, 2, 3, 14, 5, 6, //
178 1, 2, 3, -4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
179 18, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
180 8, 2, 3, 4, 5, 6, 3, 2, 3, 4, 5, 6, 13, 2, 3, 4, 5, 6, //
181 1, -2, 2, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 8, 5, -6, //
182 };
183 std::vector<int32_t> i2f_size_ = {n_cell_, n_input_};
184 TfLiteTensor i2f_tensor_;
185
186 // input_to_cell_weights.
187 std::vector<int8_t> i2c_ = {
188 1, 2, 3, 4, 5, 6, 5, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 0, //
189 1, 2, 3, 4, 3, 6, 1, 2, 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
190 8, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 16, 1, 2, 3, 14, 5, 6, //
191 1, 2, 3, -4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, //
192 18, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
193 8, 2, 3, 4, 5, 6, 3, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
194 1, -2, 2, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 8, 5, -6, //
195 8, 2, 3, 4, 3, 6, 1, -2, 3, 4, 5, 6, 1, 2, 3, -4, 5, 6, //
196 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, -5, 6, 1, 7, 3, 4, -5, 6, //
197 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
198 };
199 std::vector<int32_t> i2c_size_ = {n_cell_, n_input_};
200 TfLiteTensor i2c_tensor_;
201
202 // input_to_output_weights.
203 std::vector<int8_t> i2o_ = {
204 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, -5, 6, 1, 7, 3, 4, -5, 6, //
205 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, -1, 2, 3, 4, 5, 6, //
206 1, 2, 3, 4, 3, 6, 1, 2, 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
207 8, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 1, 2, 3, 14, 5, 6, //
208 18, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, -6, 1, 2, 3, 4, 5, 6, //
209 8, 2, 3, 4, 5, 6, 3, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
210 1, 2, 3, 4, 5, 6, 5, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 0, //
211 8, 2, 3, 4, 3, 6, 1, -2, 3, 4, 5, 6, 1, 2, 3, -4, 5, 6, //
212 1, 2, 3, -4, 5, 6, 1, 2, 3, 4, 5, 6, -1, 2, 3, 4, 5, 6, //
213 1, -2, 2, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 8, 5, -6, //
214 };
215 std::vector<int32_t> i2o_size_ = {n_cell_, n_input_};
216 TfLiteTensor i2o_tensor_;
217
218 // recurrent_to_input_weights.
219 std::vector<int8_t> r2i_ = {
220 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
221 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
222 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
223 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
224 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
225 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
226 };
227 std::vector<int32_t> r2i_size_ = {n_cell_, n_output_};
228 TfLiteTensor r2i_tensor_;
229
230 // recurrent_to_forget_weights.
231 std::vector<int8_t> r2f_ = {
232 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
233 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
234 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
235 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
236 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
237 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
238 };
239 std::vector<int32_t> r2f_size_ = {n_cell_, n_output_};
240 TfLiteTensor r2f_tensor_;
241
242 // recurrent_to_cell_weights.
243 std::vector<int8_t> r2c_ = {
244 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
245 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
246 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
247 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
248 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
249 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
250 };
251 std::vector<int32_t> r2c_size_ = {n_cell_, n_output_};
252 TfLiteTensor r2c_tensor_;
253
254 // recurrent_to_output_weights.
255 std::vector<int8_t> r2o_ = {
256 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
257 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
258 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
259 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
260 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
261 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
262 };
263 std::vector<int32_t> r2o_size_ = {n_cell_, n_output_};
264 TfLiteTensor r2o_tensor_;
265
266 std::vector<int32_t> layer_norm_input_size_ = {n_cell_};
267 TfLiteTensor layer_norm_input_tensor_;
268
269 TfLiteTensor layer_norm_forget_tensor_;
270 std::vector<int32_t> layer_norm_forget_size_ = {n_cell_};
271
272 std::vector<int32_t> layer_norm_cell_size_ = {n_cell_};
273 TfLiteTensor layer_norm_cell_tensor_;
274
275 std::vector<int32_t> layer_norm_output_size_ = {n_cell_};
276 TfLiteTensor layer_norm_output_tensor_;
277
278 std::vector<int32_t> input_gate_bias_size_ = {n_cell_};
279 TfLiteTensor input_gate_bias_tensor_;
280
281 std::vector<int32_t> forget_gate_bias_size_ = {n_cell_};
282 TfLiteTensor forget_gate_bias_tensor_;
283
284 std::vector<int32_t> cell_gate_bias_size_ = {n_cell_};
285 TfLiteTensor cell_gate_bias_tensor_;
286
287 std::vector<int32_t> output_gate_bias_size_ = {n_cell_};
288 TfLiteTensor output_gate_bias_tensor_;
289
290 // projection_weights.
291 std::vector<int8_t> projection_ = {
292 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
293 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
294 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
295 8, 2, 3, 4, 5, 6, 1, 2, 3, 4, //
296 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
297 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
298 };
299 std::vector<int32_t> projection_size_ = {n_cell_, n_output_};
300 TfLiteTensor projection_tensor_;
301
302 // projection_bias.
303 std::vector<int32_t> projection_bias_ = {
304 16, 4, 5, 6, 1, 1 //
305 };
306
307 std::vector<int32_t> projection_bias_size_ = {n_output_};
308 TfLiteTensor projection_bias_tensor_;
309
310 std::vector<int32_t> activation_size_ = {n_batch_, n_output_};
311 TfLiteTensor activation_tensor_;
312
313 std::vector<int32_t> cell_size_ = {n_batch_, n_cell_};
314 TfLiteTensor cell_tensor_;
315
316 std::vector<int32_t> output_size_ = {n_batch_, n_output_};
317 TfLiteTensor output_tensor_;
318 };
319
320 class QuantizedLstmParam : public BaseLstmParam {
321 public:
322 // Getter methods.
GetInput()323 TfLiteTensor* GetInput() {
324 PackWeightToTensor(&input_tensor_, input_, input_size_);
325 input_tensor_.data.int8 = input_.data();
326 return &input_tensor_;
327 }
GetInputLayerNorm()328 TfLiteTensor* GetInputLayerNorm() {
329 PackWeightToTensor(&layer_norm_input_tensor_, layer_norm_input_,
330 layer_norm_input_size_);
331 layer_norm_input_tensor_.data.i16 = layer_norm_input_.data();
332 return &layer_norm_input_tensor_;
333 }
GetForgetLayerNorm()334 TfLiteTensor* GetForgetLayerNorm() {
335 PackWeightToTensor(&layer_norm_forget_tensor_, layer_norm_forget_,
336 layer_norm_forget_size_);
337 layer_norm_forget_tensor_.data.i16 = layer_norm_forget_.data();
338 return &layer_norm_forget_tensor_;
339 }
GetCellLayerNorm()340 TfLiteTensor* GetCellLayerNorm() {
341 PackWeightToTensor(&layer_norm_cell_tensor_, layer_norm_cell_,
342 layer_norm_cell_size_);
343 layer_norm_cell_tensor_.data.i16 = layer_norm_cell_.data();
344 return &layer_norm_cell_tensor_;
345 }
GetOutputLayerNorm()346 TfLiteTensor* GetOutputLayerNorm() {
347 PackWeightToTensor(&layer_norm_output_tensor_, layer_norm_output_,
348 layer_norm_output_size_);
349 layer_norm_output_tensor_.data.i16 = layer_norm_output_.data();
350 return &layer_norm_output_tensor_;
351 }
GetInputBias()352 TfLiteTensor* GetInputBias() {
353 PackWeightToTensor(&input_gate_bias_tensor_, input_gate_bias_,
354 input_gate_bias_size_);
355 input_gate_bias_tensor_.data.i32 = input_gate_bias_.data();
356 return &input_gate_bias_tensor_;
357 }
GetForgetBias()358 TfLiteTensor* GetForgetBias() {
359 PackWeightToTensor(&forget_gate_bias_tensor_, forget_gate_bias_,
360 forget_gate_bias_size_);
361 forget_gate_bias_tensor_.data.i32 = forget_gate_bias_.data();
362 return &forget_gate_bias_tensor_;
363 }
GetCellBias()364 TfLiteTensor* GetCellBias() {
365 PackWeightToTensor(&cell_gate_bias_tensor_, cell_gate_bias_,
366 cell_gate_bias_size_);
367 cell_gate_bias_tensor_.data.i32 = cell_gate_bias_.data();
368 return &cell_gate_bias_tensor_;
369 }
GetOutputBias()370 TfLiteTensor* GetOutputBias() {
371 PackWeightToTensor(&output_gate_bias_tensor_, output_gate_bias_,
372 output_gate_bias_size_);
373 output_gate_bias_tensor_.data.i32 = output_gate_bias_.data();
374 return &output_gate_bias_tensor_;
375 }
GetProjectionBias()376 TfLiteTensor* GetProjectionBias() {
377 PackWeightToTensor(&projection_bias_tensor_, projection_bias_,
378 projection_bias_size_);
379 projection_bias_tensor_.data.i32 = projection_bias_.data();
380 return &projection_bias_tensor_;
381 }
382
383 // Set up quantization parameters.
GetQuantParam()384 ops::builtin::lstm_eval::IntegerLstmParameter* GetQuantParam() {
385 integer_lstm_param_.effective_input_to_input_scale_a = 1808677632;
386 integer_lstm_param_.effective_input_to_input_scale_b = -1;
387 integer_lstm_param_.effective_recurrent_to_input_scale_a = 1078887680;
388 integer_lstm_param_.effective_recurrent_to_input_scale_b = -1;
389 integer_lstm_param_.effective_cell_to_input_scale_a = 1073741824;
390 integer_lstm_param_.effective_cell_to_input_scale_b = 1;
391 integer_lstm_param_.effective_input_to_forget_scale_a = 1845996800;
392 integer_lstm_param_.effective_input_to_forget_scale_b = -3;
393 integer_lstm_param_.effective_recurrent_to_forget_scale_a = 1477412736;
394 integer_lstm_param_.effective_recurrent_to_forget_scale_b = -2;
395 integer_lstm_param_.effective_cell_to_forget_scale_a = 1073741824;
396 integer_lstm_param_.effective_cell_to_forget_scale_b = 1;
397 integer_lstm_param_.effective_input_to_cell_scale_a = 1648385408;
398 integer_lstm_param_.effective_input_to_cell_scale_b = -2;
399 integer_lstm_param_.effective_recurrent_to_cell_scale_a = 1185544192,
400 integer_lstm_param_.effective_recurrent_to_cell_scale_b = -1;
401 integer_lstm_param_.effective_input_to_output_scale_a = 1328153600;
402 integer_lstm_param_.effective_input_to_output_scale_b = -1;
403 integer_lstm_param_.effective_recurrent_to_output_scale_a = 1479582592;
404 integer_lstm_param_.effective_recurrent_to_output_scale_b = -1;
405 integer_lstm_param_.effective_cell_to_output_scale_a = 1073741824,
406 integer_lstm_param_.effective_cell_to_output_scale_b = 1;
407 integer_lstm_param_.effective_proj_scale_a = 1105682560;
408 integer_lstm_param_.effective_proj_scale_b = -8;
409 integer_lstm_param_.effective_hidden_scale_a = 0;
410 integer_lstm_param_.effective_hidden_scale_b = 0;
411 integer_lstm_param_.layer_norm_input_scale_a = 2011617664;
412 integer_lstm_param_.layer_norm_input_scale_b = -11;
413 integer_lstm_param_.layer_norm_forget_scale_a = 1968024960;
414 integer_lstm_param_.layer_norm_forget_scale_b = -13;
415 integer_lstm_param_.layer_norm_cell_scale_a = 1097334528,
416 integer_lstm_param_.layer_norm_cell_scale_b = -12;
417 integer_lstm_param_.layer_norm_output_scale_a = 1837163008;
418 integer_lstm_param_.layer_norm_output_scale_b = -12;
419 integer_lstm_param_.quantized_cell_clip = 20480;
420 integer_lstm_param_.quantized_proj_clip = 0;
421 integer_lstm_param_.cell_scale = -11;
422 integer_lstm_param_.input_variance_guard = 1;
423 integer_lstm_param_.forget_variance_guard = 2;
424 integer_lstm_param_.cell_variance_guard = 2;
425 integer_lstm_param_.output_variance_guard = 1;
426 integer_lstm_param_.hidden_zp = 0;
427 integer_lstm_param_.input_to_forget_effective_bias.reset(
428 new int32_t[n_cell_]);
429 integer_lstm_param_.recurrent_to_forget_effective_bias.reset(
430 new int32_t[n_cell_]);
431 integer_lstm_param_.input_to_cell_effective_bias.reset(
432 new int32_t[n_cell_]);
433 integer_lstm_param_.recurrent_to_cell_effective_bias.reset(
434 new int32_t[n_cell_]);
435 integer_lstm_param_.input_to_output_effective_bias.reset(
436 new int32_t[n_cell_]);
437 integer_lstm_param_.recurrent_to_output_effective_bias.reset(
438 new int32_t[n_cell_]);
439 integer_lstm_param_.input_to_input_effective_bias.reset(
440 new int32_t[n_cell_]);
441 integer_lstm_param_.recurrent_to_input_effective_bias.reset(
442 new int32_t[n_cell_]);
443 integer_lstm_param_.projection_effective_bias.reset(new int32_t[n_output_]);
444 std::fill_n(integer_lstm_param_.input_to_forget_effective_bias.get(),
445 n_cell_, 152);
446 std::fill_n(integer_lstm_param_.recurrent_to_forget_effective_bias.get(),
447 n_cell_, 315);
448 std::fill_n(integer_lstm_param_.input_to_cell_effective_bias.get(), n_cell_,
449 165);
450 std::fill_n(integer_lstm_param_.recurrent_to_cell_effective_bias.get(),
451 n_cell_, 1165);
452 std::fill_n(integer_lstm_param_.input_to_output_effective_bias.get(),
453 n_cell_, 159);
454 std::fill_n(integer_lstm_param_.recurrent_to_output_effective_bias.get(),
455 n_cell_, 915);
456 std::fill_n(integer_lstm_param_.input_to_input_effective_bias.get(),
457 n_cell_, -15);
458 std::fill_n(integer_lstm_param_.recurrent_to_input_effective_bias.get(),
459 n_cell_, 315);
460 std::fill_n(integer_lstm_param_.projection_effective_bias.get(), n_output_,
461 115);
462 return &integer_lstm_param_;
463 }
464
465 // Create scratch buffers.
GetScratch0()466 TfLiteTensor* GetScratch0() {
467 PackWeightToTensor(&scratch0_tensor_, scratch0_, scratch0_size_);
468 scratch0_tensor_.data.i16 = scratch0_.data();
469 return &scratch0_tensor_;
470 }
GetScratch1()471 TfLiteTensor* GetScratch1() {
472 PackWeightToTensor(&scratch1_tensor_, scratch1_, scratch1_size_);
473 scratch1_tensor_.data.i16 = scratch1_.data();
474 return &scratch1_tensor_;
475 }
GetScratch2()476 TfLiteTensor* GetScratch2() {
477 PackWeightToTensor(&scratch2_tensor_, scratch2_, scratch2_size_);
478 scratch2_tensor_.data.i16 = scratch2_.data();
479 return &scratch2_tensor_;
480 }
GetScratch3()481 TfLiteTensor* GetScratch3() {
482 PackWeightToTensor(&scratch3_tensor_, scratch3_, scratch3_size_);
483 scratch3_tensor_.data.i16 = scratch3_.data();
484 return &scratch3_tensor_;
485 }
GetScratch4()486 TfLiteTensor* GetScratch4() {
487 PackWeightToTensor(&scratch4_tensor_, scratch4_, scratch4_size_);
488 scratch4_tensor_.data.int8 = scratch4_.data();
489 return &scratch4_tensor_;
490 }
GetScratch5()491 TfLiteTensor* GetScratch5() {
492 PackWeightToTensor(&scratch5_tensor_, scratch5_, scratch5_size_);
493 scratch5_tensor_.data.i32 = scratch5_.data();
494 return &scratch5_tensor_;
495 }
GetActivation()496 TfLiteTensor* GetActivation() {
497 PackWeightToTensor(&activation_tensor_, activation_, activation_size_);
498 activation_tensor_.data.int8 = activation_.data();
499 activation_tensor_.params.zero_point = 50;
500 return &activation_tensor_;
501 }
GetOutput()502 TfLiteTensor* GetOutput() {
503 PackWeightToTensor(&output_tensor_, output_, output_size_);
504 output_tensor_.data.int8 = output_.data();
505 return &output_tensor_;
506 }
GetCell()507 TfLiteTensor* GetCell() {
508 PackWeightToTensor(&cell_tensor_, cell_, cell_size_);
509 cell_tensor_.data.i16 = cell_.data();
510 return &cell_tensor_;
511 }
~QuantizedLstmParam()512 ~QuantizedLstmParam() {
513 TfLiteIntArrayFree(scratch0_tensor_.dims);
514 TfLiteIntArrayFree(scratch1_tensor_.dims);
515 TfLiteIntArrayFree(scratch2_tensor_.dims);
516 TfLiteIntArrayFree(scratch3_tensor_.dims);
517 TfLiteIntArrayFree(scratch4_tensor_.dims);
518 TfLiteIntArrayFree(scratch5_tensor_.dims);
519 }
520
521 private:
522 // input.
523 std::vector<int8_t> input_ = {
524 8, 2, 3, 4, 5, 6, 1, -2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
525 1, 2, -3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, //
526 };
527
528 std::vector<int16_t> layer_norm_input_ = {8, 2, 3, 4, 5, 6, 1, 2, 3, 4};
529
530 // forget_layer_norm_coefficient.
531 std::vector<int16_t> layer_norm_forget_ = {
532 1, 2, 3, 4, 7, 3, 4, -5, 6, 3, //
533 };
534
535 // cell_layer_norm_coefficients.
536 std::vector<int16_t> layer_norm_cell_ = {
537 6, 4, 5, 6, 1, 2, 3, 4, -5, 6, //
538 };
539
540 // output_layer_norm_coefficients.
541 std::vector<int16_t> layer_norm_output_ = {
542 16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
543 };
544
545 // input_gate_bias.
546 std::vector<int32_t> input_gate_bias_ = {
547 16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
548 };
549
550 // forget_gate_bias.
551 std::vector<int32_t> forget_gate_bias_ = {
552 16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
553 };
554
555 // cell_gate_bias.
556 std::vector<int32_t> cell_gate_bias_ = {
557 16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
558 };
559
560 // output_gate_bias.
561 std::vector<int32_t> output_gate_bias_ = {
562 16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
563 };
564
565 // activation.
566 std::vector<int8_t> activation_;
567
568 // cell.
569 std::vector<int16_t> cell_ = {
570 16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
571 1, 14, 5, 6, 1, 1, 3, 4, -5, 6, //
572 };
573
574 // output.
575 std::vector<int8_t> output_ = {
576 1, 1, 3, 4, -5, 6, //
577 1, 4, 3, 4, -5, 6, //
578 };
579
580 // quantized_lstm_param
581 ops::builtin::lstm_eval::IntegerLstmParameter integer_lstm_param_;
582
583 // 5 scratch buffers.
584 std::vector<int16_t> scratch0_;
585 std::vector<int32_t> scratch0_size_ = {n_batch_, n_cell_};
586 TfLiteTensor scratch0_tensor_;
587 std::vector<int16_t> scratch1_;
588 std::vector<int32_t> scratch1_size_ = {n_batch_, n_cell_};
589 TfLiteTensor scratch1_tensor_;
590 std::vector<int16_t> scratch2_;
591 std::vector<int32_t> scratch2_size_ = {n_batch_, n_cell_};
592 TfLiteTensor scratch2_tensor_;
593 std::vector<int16_t> scratch3_;
594 std::vector<int32_t> scratch3_size_ = {n_batch_, n_cell_};
595 TfLiteTensor scratch3_tensor_;
596 std::vector<int8_t> scratch4_;
597 std::vector<int32_t> scratch4_size_ = {n_batch_, n_cell_};
598 TfLiteTensor scratch4_tensor_;
599 std::vector<int32_t> scratch5_;
600 std::vector<int32_t> scratch5_size_ = {n_batch_, n_cell_};
601 TfLiteTensor scratch5_tensor_;
602 };
603
// Runs the fully-integer (8x8_16) LSTM kernel once on the fixed inputs from
// QuantizedLstmParam and compares the resulting cell state and outputs
// against golden values.
void TestOneFullyQuantizedLSTM() {
  CpuBackendContext context;
  QuantizedLstmParam one_parameter;
  // Fetch the output/state tensors up front so the same pointers can be
  // inspected after the kernel runs.
  auto activation = one_parameter.GetActivation();
  auto output = one_parameter.GetOutput();
  auto cell = one_parameter.GetCell();
  auto param = one_parameter.GetQuantParam();
  ops::builtin::lstm_eval::EvalInteger8x8_16(
      one_parameter.GetInput(), one_parameter.Geti2i(), one_parameter.Geti2f(),
      one_parameter.Geti2c(), one_parameter.Geti2o(), one_parameter.Getr2i(),
      one_parameter.Getr2f(), one_parameter.Getr2c(), one_parameter.Getr2o(),
      // The three nullptrs are the cell-to-* (peephole) weights, unused here.
      nullptr, nullptr, nullptr, one_parameter.GetInputLayerNorm(),
      one_parameter.GetForgetLayerNorm(), one_parameter.GetCellLayerNorm(),
      one_parameter.GetOutputLayerNorm(), one_parameter.GetInputBias(),
      one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
      one_parameter.GetOutputBias(), one_parameter.GetProjection(),
      one_parameter.GetProjectionBias(), nullptr, /*forward_sequence=*/true,
      /*time_major=*/true, param, activation, cell, output,
      one_parameter.GetScratch0(), one_parameter.GetScratch1(),
      one_parameter.GetScratch2(), one_parameter.GetScratch3(),
      one_parameter.GetScratch4(), one_parameter.GetScratch5(), &context);

  // Verify results: 2 batches x 10 cells, and 2 batches x 6 outputs. The
  // expected activation equals the activation tensor's zero point (50)
  // everywhere, and the output tensor must match the activation state.
  const std::vector<int16_t> expected_cell = {
      7, 1, 3, 2, 0, 1, 0, 2, -2, 4, 1, 6, 4, 3, 0, 1, 0, 2, -2, 4,
  };
  const std::vector<int8_t> expected_activation = {
      50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
  };
  EXPECT_TRUE(ArrayEq(cell->data.i16, expected_cell.data(), 20));
  EXPECT_TRUE(ArrayEq(activation->data.int8, expected_activation.data(), 12));
  EXPECT_TRUE(ArrayEq(output->data.int8, expected_activation.data(), 12));
}
637
// Registers the standalone routine above as a gtest case.
TEST(TestOneFullyQuantizedLSTM, TestOneFullyQuantizedLSTM) {
  TestOneFullyQuantizedLSTM();
}
641
642 class HybridLstmParam : public BaseLstmParam {
643 public:
GetFloatOutput()644 TfLiteTensor* GetFloatOutput() {
645 PackWeightToTensor(&output_tensor_, output_float_, output_size_);
646 output_tensor_.data.f = output_float_.data();
647 return &output_tensor_;
648 }
GetLSTMParam()649 const TfLiteLSTMParams GetLSTMParam() {
650 return {kTfLiteActRelu, 0, 0, kTfLiteLSTMFullKernel, true};
651 }
GetScratchBuffer()652 TfLiteTensor* GetScratchBuffer() {
653 PackWeightToTensor(&scratch_buffer_tensor_, scratch_buffer_,
654 scratch_buffer_size_);
655 scratch_buffer_tensor_.data.f = scratch_buffer_.data();
656 return &scratch_buffer_tensor_;
657 }
GetInputScalingFactors()658 TfLiteTensor* GetInputScalingFactors() {
659 PackWeightToTensor(&input_sf_tensor_, input_sf_,
660 quantization_extra_scratch_buffer_sizes_);
661 input_sf_tensor_.data.f = input_sf_.data();
662 return &input_sf_tensor_;
663 }
GetAuxInputScalingFactors()664 TfLiteTensor* GetAuxInputScalingFactors() {
665 PackWeightToTensor(&aux_input_sf_tensor_, aux_input_sf_,
666 quantization_extra_scratch_buffer_sizes_);
667 aux_input_sf_tensor_.data.f = aux_input_sf_.data();
668 return &aux_input_sf_tensor_;
669 }
GetOutputStateScalingFactors()670 TfLiteTensor* GetOutputStateScalingFactors() {
671 PackWeightToTensor(&output_state_sf_tensor_, output_state_sf_,
672 quantization_extra_scratch_buffer_sizes_);
673 output_state_sf_tensor_.data.f = output_state_sf_.data();
674 return &output_state_sf_tensor_;
675 }
GetProdScalingFactors()676 TfLiteTensor* GetProdScalingFactors() {
677 PackWeightToTensor(&prod_scaling_factors_tensor_, prod_scaling_factors_,
678 quantization_extra_scratch_buffer_sizes_);
679 prod_scaling_factors_tensor_.data.f = prod_scaling_factors_.data();
680 return &prod_scaling_factors_tensor_;
681 }
GetInputQuantized()682 TfLiteTensor* GetInputQuantized() {
683 PackWeightToTensor(&input_quantized_tensor_, input_quantized_, input_size_);
684 input_quantized_tensor_.data.int8 = input_quantized_.data();
685 return &input_quantized_tensor_;
686 }
GetActivationStateQuantized()687 TfLiteTensor* GetActivationStateQuantized() {
688 PackWeightToTensor(&activation_quantized_tensor_, activation_quantized_,
689 activation_size_);
690 activation_quantized_tensor_.data.int8 = activation_quantized_.data();
691 return &activation_quantized_tensor_;
692 }
GetCellStateQuantized()693 TfLiteTensor* GetCellStateQuantized() {
694 PackWeightToTensor(&cell_quantized_tensor_, cell_quantized_, cell_size_);
695 cell_quantized_tensor_.data.int8 = cell_quantized_.data();
696 return &cell_quantized_tensor_;
697 }
GetInputZeroPoints()698 TfLiteTensor* GetInputZeroPoints() {
699 PackWeightToTensor(&input_zp_tensor_, input_zp_,
700 quantization_extra_scratch_buffer_sizes_);
701 input_zp_tensor_.data.i32 = input_zp_.data();
702 return &input_zp_tensor_;
703 }
GetAuxInputZeroPoints()704 TfLiteTensor* GetAuxInputZeroPoints() {
705 PackWeightToTensor(&aux_input_zp_tensor_, aux_input_zp_,
706 quantization_extra_scratch_buffer_sizes_);
707 aux_input_zp_tensor_.data.i32 = aux_input_zp_.data();
708 return &aux_input_zp_tensor_;
709 }
GetOutputStateZeroPoints()710 TfLiteTensor* GetOutputStateZeroPoints() {
711 PackWeightToTensor(&output_state_zp_tensor_, output_state_zp_,
712 quantization_extra_scratch_buffer_sizes_);
713 output_state_zp_tensor_.data.i32 = output_state_zp_.data();
714 return &output_state_zp_tensor_;
715 }
GetRowSums()716 TfLiteTensor* GetRowSums() {
717 PackWeightToTensor(&row_sums_tensor_, row_sums_, row_sums_size_);
718 row_sums_tensor_.data.i32 = row_sums_.data();
719 return &row_sums_tensor_;
720 }
GetFloatInput()721 TfLiteTensor* GetFloatInput() {
722 PackWeightToTensor(&input_tensor_, input_float_, input_size_);
723 input_tensor_.data.f = input_float_.data();
724 return &input_tensor_;
725 }
GetActivation()726 TfLiteTensor* GetActivation() {
727 PackWeightToTensor(&activation_tensor_, activation_state_,
728 activation_size_);
729 activation_tensor_.data.f = activation_state_.data();
730 return &activation_tensor_;
731 }
GetCell()732 TfLiteTensor* GetCell() {
733 PackWeightToTensor(&cell_tensor_, cell_state_, cell_size_);
734 cell_tensor_.data.f = cell_state_.data();
735 return &cell_tensor_;
736 }
GetAccumScratchBuffer()737 TfLiteTensor* GetAccumScratchBuffer() {
738 PackWeightToTensor(&accum_scratch_tensor_, accum_scratch_,
739 accum_scratch_size_);
740 accum_scratch_tensor_.data.i32 = accum_scratch_.data();
741 return &accum_scratch_tensor_;
742 }
GetInputBias()743 TfLiteTensor* GetInputBias() {
744 PackWeightToTensor(&input_gate_bias_tensor_, input_float_bias_,
745 input_gate_bias_size_);
746 input_gate_bias_tensor_.data.f = input_float_bias_.data();
747 return &input_gate_bias_tensor_;
748 }
GetForgetBias()749 TfLiteTensor* GetForgetBias() {
750 PackWeightToTensor(&forget_gate_bias_tensor_, forget_float_bias_,
751 forget_gate_bias_size_);
752 forget_gate_bias_tensor_.data.f = forget_float_bias_.data();
753 return &forget_gate_bias_tensor_;
754 }
GetCellBias()755 TfLiteTensor* GetCellBias() {
756 PackWeightToTensor(&cell_gate_bias_tensor_, cell_float_bias_,
757 cell_gate_bias_size_);
758 cell_gate_bias_tensor_.data.f = cell_float_bias_.data();
759 return &cell_gate_bias_tensor_;
760 }
GetOutputBias()761 TfLiteTensor* GetOutputBias() {
762 PackWeightToTensor(&output_gate_bias_tensor_, output_float_bias_,
763 output_gate_bias_size_);
764 output_gate_bias_tensor_.data.f = output_float_bias_.data();
765 return &output_gate_bias_tensor_;
766 }
GetProjectionBias()767 TfLiteTensor* GetProjectionBias() {
768 PackWeightToTensor(&projection_bias_tensor_, projection_float_bias_,
769 projection_bias_size_);
770 projection_bias_tensor_.data.f = projection_float_bias_.data();
771 return &projection_bias_tensor_;
772 }
GetNumRowSums()773 int GetNumRowSums() { return n_row_sums_; }
GetInputLayerNorm()774 TfLiteTensor* GetInputLayerNorm() {
775 PackWeightToTensor(&layer_norm_input_tensor_, layer_norm_float_input_,
776 layer_norm_input_size_);
777 layer_norm_input_tensor_.data.f = layer_norm_float_input_.data();
778 return &layer_norm_input_tensor_;
779 }
GetForgetLayerNorm()780 TfLiteTensor* GetForgetLayerNorm() {
781 PackWeightToTensor(&layer_norm_forget_tensor_, layer_norm_float_forget_,
782 layer_norm_forget_size_);
783 layer_norm_forget_tensor_.data.f = layer_norm_float_forget_.data();
784 return &layer_norm_forget_tensor_;
785 }
GetCellLayerNorm()786 TfLiteTensor* GetCellLayerNorm() {
787 PackWeightToTensor(&layer_norm_cell_tensor_, layer_norm_float_cell_,
788 layer_norm_cell_size_);
789 layer_norm_cell_tensor_.data.f = layer_norm_float_cell_.data();
790 return &layer_norm_cell_tensor_;
791 }
GetOutputLayerNorm()792 TfLiteTensor* GetOutputLayerNorm() {
793 PackWeightToTensor(&layer_norm_output_tensor_, layer_norm_float_output_,
794 layer_norm_output_size_);
795 layer_norm_output_tensor_.data.f = layer_norm_float_output_.data();
796 return &layer_norm_output_tensor_;
797 }
addScale(TfLiteTensor * t,float scale)798 static TfLiteTensor* addScale(TfLiteTensor* t, float scale) {
799 t->params.scale = scale;
800 return t;
801 }
~HybridLstmParam()802 ~HybridLstmParam() {
803 TfLiteIntArrayFree(scratch_buffer_tensor_.dims);
804 TfLiteIntArrayFree(accum_scratch_tensor_.dims);
805 TfLiteIntArrayFree(input_sf_tensor_.dims);
806 TfLiteIntArrayFree(aux_input_sf_tensor_.dims);
807 TfLiteIntArrayFree(output_state_sf_tensor_.dims);
808 TfLiteIntArrayFree(prod_scaling_factors_tensor_.dims);
809 TfLiteIntArrayFree(input_quantized_tensor_.dims);
810 TfLiteIntArrayFree(activation_quantized_tensor_.dims);
811 TfLiteIntArrayFree(cell_quantized_tensor_.dims);
812 TfLiteIntArrayFree(input_zp_tensor_.dims);
813 TfLiteIntArrayFree(aux_input_zp_tensor_.dims);
814 TfLiteIntArrayFree(output_state_zp_tensor_.dims);
815 TfLiteIntArrayFree(row_sums_tensor_.dims);
816 }
817
 private:
  const int n_row_sums_ = 9;  // Number of weights + 1 for projection weights.

  // Float scratch area for gate computations (n_batch x 4*n_cell).
  std::vector<float> scratch_buffer_;
  std::vector<int32_t> scratch_buffer_size_ = {n_batch_, n_cell_ * 4};
  TfLiteTensor scratch_buffer_tensor_;

  // Per-batch scaling factors used when quantizing float inputs on the fly.
  std::vector<int32_t> quantization_extra_scratch_buffer_sizes_ = {n_batch_};
  std::vector<float> input_sf_;
  TfLiteTensor input_sf_tensor_;
  std::vector<float> aux_input_sf_;
  TfLiteTensor aux_input_sf_tensor_;
  std::vector<float> output_state_sf_;
  TfLiteTensor output_state_sf_tensor_;

  std::vector<float> prod_scaling_factors_;
  TfLiteTensor prod_scaling_factors_tensor_;

  // Per-batch zero points for the asymmetric quantization path.
  std::vector<int32_t> input_zp_;
  TfLiteTensor input_zp_tensor_;
  std::vector<int32_t> aux_input_zp_;
  TfLiteTensor aux_input_zp_tensor_;
  std::vector<int32_t> output_state_zp_;
  TfLiteTensor output_state_zp_tensor_;

  // int8 buffers; presumably hold runtime-quantized copies of the float
  // input/activation/cell tensors — see the Get*Quantized accessors.
  std::vector<int8_t> input_quantized_;
  TfLiteTensor input_quantized_tensor_;

  std::vector<int8_t> activation_quantized_;
  TfLiteTensor activation_quantized_tensor_;

  std::vector<int8_t> cell_quantized_;
  TfLiteTensor cell_quantized_tensor_;

  // Initial cell-state values fed into the evaluation.
  std::vector<float> cell_state_ = {
      16, 4, 5, 6, 1, 1, 3, 4, -5, 6, 1, 14, 5, 6, 1, 1, 3, 4, -5, 6,
  };

  // Row-sum cache: n_row_sums_ rows of n_cell_ entries.
  std::vector<int32_t> row_sums_;
  std::vector<int32_t> row_sums_size_ = {n_row_sums_, n_cell_};
  TfLiteTensor row_sums_tensor_;

  std::vector<float> activation_state_;

  // Accumulator scratch for the matmul (n_cell x n_batch).
  std::vector<int32_t> accum_scratch_;
  std::vector<int32_t> accum_scratch_size_ = {n_cell_, n_batch_};
  TfLiteTensor accum_scratch_tensor_;
  std::vector<float> output_float_ = {
      1, 1, 3, 4, -5, 6,  //
      1, 4, 3, 4, -5, 6,  //
  };
  // Two time steps of float input, one row per batch entry.
  std::vector<float> input_float_ = {
      6.06, 7.66, 7.10, 9.32, 3.85, 0.33, 7.15, 1.56, 9.54,
      5.30, 4.53, 0.19, 1.83, 4.60, 0.84, 5.08, 4.37, 9.92,  //
      4.08, 3.79, 1.17, 8.99, 0.14, 9.22, 3.18, 2.97, 7.53,
      0.59, 9.89, 9.13, 7.68, 0.63, 2.15, 4.31, 7.20, 4.09,  //
  };
  // Gate biases: input/forget are zero; cell/output are non-trivial so the
  // bias-add paths are actually exercised.
  std::vector<float> input_float_bias_ = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  std::vector<float> forget_float_bias_ = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  std::vector<float> cell_float_bias_ = {
      -11, -7, -4, -5, -1, -1, -2, -3.5, -3, -4,
  };
  std::vector<float> output_float_bias_ = {0.16, 0.4, 0.5, 0.6, 0.1,
                                           0.1, 0.3, 0.4, -0.5, 0.6};
  std::vector<float> projection_float_bias_ = {0, 0, 0, 0, 0, 0};
  // Layer-normalization coefficients for each of the four gates.
  std::vector<float> layer_norm_float_input_ = {8, 2, 3, 4, 5, 6, 1, -2, 3, 4};
  std::vector<float> layer_norm_float_forget_ = {
      0.1, 0.2, 0.3, 0.4, 0.7, 0.3, 0.4, -0.5, 0.6, 0.3,  //
  };
  std::vector<float> layer_norm_float_cell_ = {
      0.6, 0.4, 0.5, 0.6, 0.1, 0.2, 0.3, 0.4, -0.5, 0.6,  //
  };
  std::vector<float> layer_norm_float_output_ = {
      0.6, 0.4, 0.5, 0.6, 0.1, 0.2, 0.3, 0.4, -0.5, 0.6,  //
  };
};
898
TestOneHybridAsymmLSTM()899 void TestOneHybridAsymmLSTM() {
900 CpuBackendContext context;
901 HybridLstmParam one_parameter;
902 auto activation = one_parameter.GetActivation();
903 auto output = one_parameter.GetFloatOutput();
904 auto cell = one_parameter.GetCell();
905 auto param = one_parameter.GetLSTMParam();
906 bool compute_row_sums = true;
907 constexpr float kDefaultScale = 18.0;
908 ops::builtin::lstm_eval::EvalHybrid(
909 one_parameter.GetFloatInput(),
910 HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale), nullptr,
911 HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale), nullptr,
912 HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale), nullptr,
913 HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale), nullptr,
914 HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale), nullptr,
915 HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale), nullptr,
916 HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale), nullptr,
917 HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale), nullptr,
918 /*cell_to_input_weights=*/nullptr,
919 /*cell_to_forget_weights=*/nullptr,
920 /*cell_to_output_weights=*/nullptr, one_parameter.GetInputLayerNorm(),
921 one_parameter.GetForgetLayerNorm(), one_parameter.GetCellLayerNorm(),
922 one_parameter.GetOutputLayerNorm(),
923 /*aux_input=*/nullptr,
924 /*aux_input_to_input_weights=*/nullptr,
925 /*aux_input_to_forget_weights=*/nullptr,
926 /*aux_input_to_cell_weights=*/nullptr,
927 /*aux_input_to_output_weights=*/nullptr, one_parameter.GetInputBias(),
928 one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
929 one_parameter.GetOutputBias(),
930 HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0), nullptr,
931 one_parameter.GetProjectionBias(), ¶m,
932 /*forward_sequence=*/true,
933 /*time_major=*/true,
934 /*output_offset=*/0, one_parameter.GetScratchBuffer(),
935 one_parameter.GetInputScalingFactors(),
936 one_parameter.GetAuxInputScalingFactors(),
937 one_parameter.GetOutputStateScalingFactors(),
938 one_parameter.GetProdScalingFactors(),
939 /*recovered_cell_weights=*/nullptr, one_parameter.GetInputQuantized(),
940 /*aux_input_quantized=*/nullptr,
941 one_parameter.GetActivationStateQuantized(),
942 one_parameter.GetCellStateQuantized(), activation, cell,
943 one_parameter.GetAccumScratchBuffer(), output,
944 one_parameter.GetInputZeroPoints(), one_parameter.GetAuxInputZeroPoints(),
945 one_parameter.GetOutputStateZeroPoints(), one_parameter.GetRowSums(),
946 one_parameter.GetNumRowSums(), &compute_row_sums, &context);
947 const std::vector<float> expected_cell = {
948 7.83134, 1.96158, 2.18285, 3.28739, 0.483214,
949 0.618206, 1.21539, 1.4052, -3.17735, 2.24296, //
950 0.498944, 6.91104, 1.74126, 3.28993, 0.580477,
951 0.489936, 1.2527, 1.50157, -3.71849, 2.76743, //
952 };
953 const std::vector<float> expected_activation = {
954 53.0403, 59.3623, 24.8493, 53.0403, 59.3623, 24.8493, //
955 36.7559, 57.5202, 29.7217, 36.7559, 57.5202, 29.7217,
956 };
957 EXPECT_TRUE(ArrayFloatNear(cell->data.f, expected_cell.data(), 20, 1e-2));
958 EXPECT_TRUE(
959 ArrayFloatNear(activation->data.f, expected_activation.data(), 12, 1e-4));
960 EXPECT_TRUE(
961 ArrayFloatNear(output->data.f, expected_activation.data(), 12, 1e-4));
962 }
963
// Single end-to-end hybrid asymmetric LSTM evaluation scenario.
TEST(TestOneHybridAsymmLSTM, TestOneHybridAsymmLSTM) {
  TestOneHybridAsymmLSTM();
}
967
968 } // namespace
969 } // namespace tflite
970