/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/lstm_eval.h"

#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"

namespace tflite {
namespace {

// Returns true if the first `size` elements of `result` and `expected_result`
// are exactly equal.
template <typename T>
bool ArrayEq(const T* result, const T* expected_result, int size) {
  for (int i = 0; i < size; ++i) {
    if (result[i] != expected_result[i]) {
      return false;
    }
  }
  return true;
}

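// Returns true if every element of `result` is within `threshold` of the
// corresponding element of `expected_result`.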
43 template <typename T>
ArrayFloatNear(const T * result,const T * expected_result,int size,double threshold)44 bool ArrayFloatNear(const T* result, const T* expected_result, int size,
45                     double threshold) {
46   for (int i = 0; i < size; ++i) {
47     if (std::abs(result[i] - expected_result[i]) > threshold) {
48       return false;
49     }
50   }
51   return true;
52 }
53 
// Base class that holds the input parameters for the quantized and hybrid
// LSTM tests.
class BaseLstmParam {
 public:
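  // Weight getters. Naming follows the LSTM kernel convention: i2i/i2f/i2c/i2o
  // are the input-to-{input,forget,cell,output} gate weights and
  // r2i/r2f/r2c/r2o are the corresponding recurrent (state-to-gate) weights.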
  TfLiteTensor* Geti2i() {
    PackWeightToTensor(&i2i_tensor_, i2i_, i2i_size_);
    i2i_tensor_.data.int8 = i2i_.data();
    return &i2i_tensor_;
  }
  TfLiteTensor* Geti2f() {
    PackWeightToTensor(&i2f_tensor_, i2f_, i2f_size_);
    i2f_tensor_.data.int8 = i2f_.data();
    return &i2f_tensor_;
  }
  TfLiteTensor* Geti2c() {
    PackWeightToTensor(&i2c_tensor_, i2c_, i2c_size_);
    i2c_tensor_.data.int8 = i2c_.data();
    return &i2c_tensor_;
  }
  TfLiteTensor* Geti2o() {
    PackWeightToTensor(&i2o_tensor_, i2o_, i2o_size_);
    i2o_tensor_.data.int8 = i2o_.data();
    return &i2o_tensor_;
  }
  TfLiteTensor* Getr2i() {
    PackWeightToTensor(&r2i_tensor_, r2i_, r2i_size_);
    r2i_tensor_.data.int8 = r2i_.data();
    return &r2i_tensor_;
  }
  TfLiteTensor* Getr2f() {
    PackWeightToTensor(&r2f_tensor_, r2f_, r2f_size_);
    r2f_tensor_.data.int8 = r2f_.data();
    return &r2f_tensor_;
  }
  TfLiteTensor* Getr2c() {
    PackWeightToTensor(&r2c_tensor_, r2c_, r2c_size_);
    r2c_tensor_.data.int8 = r2c_.data();
    return &r2c_tensor_;
  }
  TfLiteTensor* Getr2o() {
    PackWeightToTensor(&r2o_tensor_, r2o_, r2o_size_);
    r2o_tensor_.data.int8 = r2o_.data();
    return &r2o_tensor_;
  }
  TfLiteTensor* GetProjection() {
    PackWeightToTensor(&projection_tensor_, projection_, projection_size_);
    projection_tensor_.data.int8 = projection_.data();
    return &projection_tensor_;
  }
  ~BaseLstmParam() {
    TfLiteIntArrayFree(input_tensor_.dims);
    TfLiteIntArrayFree(i2i_tensor_.dims);
    TfLiteIntArrayFree(i2f_tensor_.dims);
    TfLiteIntArrayFree(i2c_tensor_.dims);
    TfLiteIntArrayFree(i2o_tensor_.dims);
    TfLiteIntArrayFree(r2i_tensor_.dims);
    TfLiteIntArrayFree(r2f_tensor_.dims);
    TfLiteIntArrayFree(r2c_tensor_.dims);
    TfLiteIntArrayFree(r2o_tensor_.dims);
    TfLiteIntArrayFree(layer_norm_input_tensor_.dims);
    TfLiteIntArrayFree(layer_norm_forget_tensor_.dims);
    TfLiteIntArrayFree(layer_norm_cell_tensor_.dims);
    TfLiteIntArrayFree(layer_norm_output_tensor_.dims);
    TfLiteIntArrayFree(input_gate_bias_tensor_.dims);
    TfLiteIntArrayFree(forget_gate_bias_tensor_.dims);
    TfLiteIntArrayFree(cell_gate_bias_tensor_.dims);
    TfLiteIntArrayFree(output_gate_bias_tensor_.dims);
    TfLiteIntArrayFree(projection_tensor_.dims);
    TfLiteIntArrayFree(projection_bias_tensor_.dims);
    TfLiteIntArrayFree(activation_tensor_.dims);
    TfLiteIntArrayFree(cell_tensor_.dims);
    TfLiteIntArrayFree(output_tensor_.dims);
  }

 protected:
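  // Resizes `tensor` to `dims` and, if `data` is empty, zero-fills it so the
  // tensor always has correctly sized backing storage.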
  template <typename T>
  void PackWeightToTensor(TfLiteTensor* tensor, std::vector<T>& data,
                          std::vector<int32_t> dims) {
    if (data.empty()) {
      int total = 1;
      for (int i = 0; i < dims.size(); ++i) {
        total *= dims[i];
      }
      for (int i = 0; i < total; ++i) {
        data.push_back(0);
      }
    }
    tensor->dims = TfLiteIntArrayCreate(dims.size());
    for (int i = 0; i < dims.size(); ++i) {
      tensor->dims->data[i] = dims[i];
    }
  }
  // Dimensions. The sizes must be large enough to trigger the NEON code paths.
  const int n_batch_ = 2;
  const int n_input_ = 18;
  const int n_cell_ = 10;
  const int n_output_ = 6;

  std::vector<int32_t> input_size_ = {n_batch_, n_input_};
  TfLiteTensor input_tensor_;

  // input_to_input_weights.
  std::vector<int8_t> i2i_ = {
      18, 2,  3, 4,  5, 6, 1, 2,  3, 4, 5,  6, 1, 2, 3, 4,  5,  6,   //
      1,  2,  3, 4,  5, 6, 5, 2,  3, 4, 5,  6, 1, 2, 3, 4,  5,  0,   //
      8,  2,  3, 4,  3, 6, 1, -2, 3, 4, 5,  6, 1, 2, 3, -4, 5,  6,   //
      1,  2,  3, 4,  5, 6, 1, 2,  3, 4, -5, 6, 1, 7, 3, 4,  -5, 6,   //
      8,  2,  3, 4,  5, 6, 3, 2,  3, 4, 5,  6, 1, 2, 3, 4,  5,  6,   //
      1,  -2, 2, 4,  5, 6, 1, 2,  3, 4, 5,  6, 1, 2, 3, 8,  5,  -6,  //
      8,  2,  3, 4,  5, 6, 1, 2,  3, 4, 5,  6, 1, 2, 3, 4,  5,  6,   //
      1,  2,  3, 4,  3, 6, 1, 2,  6, 4, 5,  6, 1, 2, 3, 4,  -5, 6,   //
      8,  2,  3, 4,  5, 6, 7, 2,  3, 4, 5,  6, 1, 2, 3, 14, 5,  6,   //
      1,  2,  3, -4, 5, 6, 1, 2,  3, 4, 5,  6, 1, 2, 3, 4,  5,  6,   //
  };
  std::vector<int32_t> i2i_size_ = {n_cell_, n_input_};
  TfLiteTensor i2i_tensor_;

  // input_to_forget_weights.
  std::vector<int8_t> i2f_ = {
      1,  2,  3, 4,  5, 6, 5, 2,  3, 4, 5,  6,  1,  2, 3, 4,  5,  0,   //
      8,  2,  3, 4,  3, 6, 1, -2, 3, 4, 5,  6,  1,  2, 3, -4, 5,  6,   //
      1,  2,  3, 4,  5, 6, 1, 2,  3, 4, -5, 6,  1,  7, 3, 4,  -5, 6,   //
      8,  2,  3, 4,  5, 6, 1, 2,  3, 4, 5,  6,  1,  2, 3, 4,  5,  6,   //
      1,  2,  3, 4,  3, 6, 1, 2,  6, 4, 5,  6,  11, 2, 3, 4,  -5, 6,   //
      8,  2,  3, 4,  5, 6, 7, 2,  3, 4, 5,  -6, 1,  2, 3, 14, 5,  6,   //
      1,  2,  3, -4, 5, 6, 1, 2,  3, 4, 5,  6,  1,  2, 3, 4,  5,  6,   //
      18, 2,  3, 4,  5, 6, 1, 2,  3, 4, 5,  6,  1,  2, 3, 4,  5,  6,   //
      8,  2,  3, 4,  5, 6, 3, 2,  3, 4, 5,  6,  13, 2, 3, 4,  5,  6,   //
      1,  -2, 2, 4,  5, 6, 1, 2,  3, 4, 5,  6,  1,  2, 3, 8,  5,  -6,  //
  };
  std::vector<int32_t> i2f_size_ = {n_cell_, n_input_};
  TfLiteTensor i2f_tensor_;

  // input_to_cell_weights.
  std::vector<int8_t> i2c_ = {
      1,  2,  3, 4,  5, 6, 5, 2,  3, 4, 5,  6,  1, 2, 3, 4,  5,  0,   //
      1,  2,  3, 4,  3, 6, 1, 2,  6, 4, 5,  6,  1, 2, 3, 4,  -5, 6,   //
      8,  2,  3, 4,  5, 6, 7, 2,  3, 4, 5,  16, 1, 2, 3, 14, 5,  6,   //
      1,  2,  3, -4, 5, 6, 1, 2,  3, 4, 5,  6,  7, 2, 3, 4,  5,  6,   //
      18, 2,  3, 4,  5, 6, 1, 2,  3, 4, 5,  6,  1, 2, 3, 4,  5,  6,   //
      8,  2,  3, 4,  5, 6, 3, 2,  3, 4, 5,  6,  1, 2, 3, 4,  5,  6,   //
      1,  -2, 2, 4,  5, 6, 1, 2,  3, 4, 5,  6,  1, 2, 3, 8,  5,  -6,  //
      8,  2,  3, 4,  3, 6, 1, -2, 3, 4, 5,  6,  1, 2, 3, -4, 5,  6,   //
      1,  2,  3, 4,  5, 6, 1, 2,  3, 4, -5, 6,  1, 7, 3, 4,  -5, 6,   //
      8,  2,  3, 4,  5, 6, 1, 2,  3, 4, 5,  6,  1, 2, 3, 4,  5,  6,   //
  };
  std::vector<int32_t> i2c_size_ = {n_cell_, n_input_};
  TfLiteTensor i2c_tensor_;

  // input_to_output_weights.
  std::vector<int8_t> i2o_ = {
      1,  2,  3, 4,  5, 6, 1, 2,  3, 4, -5, 6,  1,  7, 3, 4,  -5, 6,   //
      8,  2,  3, 4,  5, 6, 1, 2,  3, 4, 5,  6,  -1, 2, 3, 4,  5,  6,   //
      1,  2,  3, 4,  3, 6, 1, 2,  6, 4, 5,  6,  1,  2, 3, 4,  -5, 6,   //
      8,  2,  3, 4,  5, 6, 7, 2,  3, 4, 5,  6,  1,  2, 3, 14, 5,  6,   //
      18, 2,  3, 4,  5, 6, 1, 2,  3, 4, 5,  -6, 1,  2, 3, 4,  5,  6,   //
      8,  2,  3, 4,  5, 6, 3, 2,  3, 4, 5,  6,  1,  2, 3, 4,  5,  6,   //
      1,  2,  3, 4,  5, 6, 5, 2,  3, 4, 5,  6,  1,  2, 3, 4,  5,  0,   //
      8,  2,  3, 4,  3, 6, 1, -2, 3, 4, 5,  6,  1,  2, 3, -4, 5,  6,   //
      1,  2,  3, -4, 5, 6, 1, 2,  3, 4, 5,  6,  -1, 2, 3, 4,  5,  6,   //
      1,  -2, 2, 4,  5, 6, 1, 2,  3, 4, 5,  6,  1,  2, 3, 8,  5,  -6,  //
  };
  std::vector<int32_t> i2o_size_ = {n_cell_, n_input_};
  TfLiteTensor i2o_tensor_;

  // recurrent_to_input_weights.
  std::vector<int8_t> r2i_ = {
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
  };
  std::vector<int32_t> r2i_size_ = {n_cell_, n_output_};
  TfLiteTensor r2i_tensor_;

  // recurrent_to_forget_weights.
  std::vector<int8_t> r2f_ = {
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
  };
  std::vector<int32_t> r2f_size_ = {n_cell_, n_output_};
  TfLiteTensor r2f_tensor_;

  // recurrent_to_cell_weights.
  std::vector<int8_t> r2c_ = {
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
  };
  std::vector<int32_t> r2c_size_ = {n_cell_, n_output_};
  TfLiteTensor r2c_tensor_;

  // recurrent_to_output_weights.
  std::vector<int8_t> r2o_ = {
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
  };
  std::vector<int32_t> r2o_size_ = {n_cell_, n_output_};
  TfLiteTensor r2o_tensor_;

  std::vector<int32_t> layer_norm_input_size_ = {n_cell_};
  TfLiteTensor layer_norm_input_tensor_;

  TfLiteTensor layer_norm_forget_tensor_;
  std::vector<int32_t> layer_norm_forget_size_ = {n_cell_};

  std::vector<int32_t> layer_norm_cell_size_ = {n_cell_};
  TfLiteTensor layer_norm_cell_tensor_;

  std::vector<int32_t> layer_norm_output_size_ = {n_cell_};
  TfLiteTensor layer_norm_output_tensor_;

  std::vector<int32_t> input_gate_bias_size_ = {n_cell_};
  TfLiteTensor input_gate_bias_tensor_;

  std::vector<int32_t> forget_gate_bias_size_ = {n_cell_};
  TfLiteTensor forget_gate_bias_tensor_;

  std::vector<int32_t> cell_gate_bias_size_ = {n_cell_};
  TfLiteTensor cell_gate_bias_tensor_;

  std::vector<int32_t> output_gate_bias_size_ = {n_cell_};
  TfLiteTensor output_gate_bias_tensor_;

  // projection_weights.
  std::vector<int8_t> projection_ = {
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
      8, 2, 3, 4, 5, 6, 1, 2,  3,  4,  //
      6, 4, 5, 6, 1, 2, 3, 4,  -5, 6,  //
      1, 2, 3, 4, 7, 3, 4, -5, 6,  3,  //
  };
  std::vector<int32_t> projection_size_ = {n_cell_, n_output_};
  TfLiteTensor projection_tensor_;

  // projection_bias.
  std::vector<int32_t> projection_bias_ = {
      16, 4, 5, 6, 1, 1  //
  };

  std::vector<int32_t> projection_bias_size_ = {n_output_};
  TfLiteTensor projection_bias_tensor_;

  std::vector<int32_t> activation_size_ = {n_batch_, n_output_};
  TfLiteTensor activation_tensor_;

  std::vector<int32_t> cell_size_ = {n_batch_, n_cell_};
  TfLiteTensor cell_tensor_;

  std::vector<int32_t> output_size_ = {n_batch_, n_output_};
  TfLiteTensor output_tensor_;
};

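// Input parameters and scratch tensors for the fully integer-quantized
// (8x8_16) LSTM test.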
class QuantizedLstmParam : public BaseLstmParam {
 public:
  // Getter methods.
  TfLiteTensor* GetInput() {
    PackWeightToTensor(&input_tensor_, input_, input_size_);
    input_tensor_.data.int8 = input_.data();
    return &input_tensor_;
  }
  TfLiteTensor* GetInputLayerNorm() {
    PackWeightToTensor(&layer_norm_input_tensor_, layer_norm_input_,
                       layer_norm_input_size_);
    layer_norm_input_tensor_.data.i16 = layer_norm_input_.data();
    return &layer_norm_input_tensor_;
  }
  TfLiteTensor* GetForgetLayerNorm() {
    PackWeightToTensor(&layer_norm_forget_tensor_, layer_norm_forget_,
                       layer_norm_forget_size_);
    layer_norm_forget_tensor_.data.i16 = layer_norm_forget_.data();
    return &layer_norm_forget_tensor_;
  }
  TfLiteTensor* GetCellLayerNorm() {
    PackWeightToTensor(&layer_norm_cell_tensor_, layer_norm_cell_,
                       layer_norm_cell_size_);
    layer_norm_cell_tensor_.data.i16 = layer_norm_cell_.data();
    return &layer_norm_cell_tensor_;
  }
  TfLiteTensor* GetOutputLayerNorm() {
    PackWeightToTensor(&layer_norm_output_tensor_, layer_norm_output_,
                       layer_norm_output_size_);
    layer_norm_output_tensor_.data.i16 = layer_norm_output_.data();
    return &layer_norm_output_tensor_;
  }
  TfLiteTensor* GetInputBias() {
    PackWeightToTensor(&input_gate_bias_tensor_, input_gate_bias_,
                       input_gate_bias_size_);
    input_gate_bias_tensor_.data.i32 = input_gate_bias_.data();
    return &input_gate_bias_tensor_;
  }
  TfLiteTensor* GetForgetBias() {
    PackWeightToTensor(&forget_gate_bias_tensor_, forget_gate_bias_,
                       forget_gate_bias_size_);
    forget_gate_bias_tensor_.data.i32 = forget_gate_bias_.data();
    return &forget_gate_bias_tensor_;
  }
  TfLiteTensor* GetCellBias() {
    PackWeightToTensor(&cell_gate_bias_tensor_, cell_gate_bias_,
                       cell_gate_bias_size_);
    cell_gate_bias_tensor_.data.i32 = cell_gate_bias_.data();
    return &cell_gate_bias_tensor_;
  }
  TfLiteTensor* GetOutputBias() {
    PackWeightToTensor(&output_gate_bias_tensor_, output_gate_bias_,
                       output_gate_bias_size_);
    output_gate_bias_tensor_.data.i32 = output_gate_bias_.data();
    return &output_gate_bias_tensor_;
  }
  TfLiteTensor* GetProjectionBias() {
    PackWeightToTensor(&projection_bias_tensor_, projection_bias_,
                       projection_bias_size_);
    projection_bias_tensor_.data.i32 = projection_bias_.data();
    return &projection_bias_tensor_;
  }

  // Set up the precomputed quantization parameters. Each effective scale is
  // stored as a 32-bit fixed-point multiplier (`*_scale_a`) plus an exponent
  // (`*_scale_b`).
  ops::builtin::lstm_eval::IntegerLstmParameter* GetQuantParam() {
    integer_lstm_param_.effective_input_to_input_scale_a = 1808677632;
    integer_lstm_param_.effective_input_to_input_scale_b = -1;
    integer_lstm_param_.effective_recurrent_to_input_scale_a = 1078887680;
    integer_lstm_param_.effective_recurrent_to_input_scale_b = -1;
    integer_lstm_param_.effective_cell_to_input_scale_a = 1073741824;
    integer_lstm_param_.effective_cell_to_input_scale_b = 1;
    integer_lstm_param_.effective_input_to_forget_scale_a = 1845996800;
    integer_lstm_param_.effective_input_to_forget_scale_b = -3;
    integer_lstm_param_.effective_recurrent_to_forget_scale_a = 1477412736;
    integer_lstm_param_.effective_recurrent_to_forget_scale_b = -2;
    integer_lstm_param_.effective_cell_to_forget_scale_a = 1073741824;
    integer_lstm_param_.effective_cell_to_forget_scale_b = 1;
    integer_lstm_param_.effective_input_to_cell_scale_a = 1648385408;
    integer_lstm_param_.effective_input_to_cell_scale_b = -2;
    integer_lstm_param_.effective_recurrent_to_cell_scale_a = 1185544192;
    integer_lstm_param_.effective_recurrent_to_cell_scale_b = -1;
    integer_lstm_param_.effective_input_to_output_scale_a = 1328153600;
    integer_lstm_param_.effective_input_to_output_scale_b = -1;
    integer_lstm_param_.effective_recurrent_to_output_scale_a = 1479582592;
    integer_lstm_param_.effective_recurrent_to_output_scale_b = -1;
    integer_lstm_param_.effective_cell_to_output_scale_a = 1073741824;
    integer_lstm_param_.effective_cell_to_output_scale_b = 1;
    integer_lstm_param_.effective_proj_scale_a = 1105682560;
    integer_lstm_param_.effective_proj_scale_b = -8;
    integer_lstm_param_.effective_hidden_scale_a = 0;
    integer_lstm_param_.effective_hidden_scale_b = 0;
    integer_lstm_param_.layer_norm_input_scale_a = 2011617664;
    integer_lstm_param_.layer_norm_input_scale_b = -11;
    integer_lstm_param_.layer_norm_forget_scale_a = 1968024960;
    integer_lstm_param_.layer_norm_forget_scale_b = -13;
    integer_lstm_param_.layer_norm_cell_scale_a = 1097334528;
    integer_lstm_param_.layer_norm_cell_scale_b = -12;
    integer_lstm_param_.layer_norm_output_scale_a = 1837163008;
    integer_lstm_param_.layer_norm_output_scale_b = -12;
    integer_lstm_param_.quantized_cell_clip = 20480;
    integer_lstm_param_.quantized_proj_clip = 0;
    integer_lstm_param_.cell_scale = -11;
    integer_lstm_param_.input_variance_guard = 1;
    integer_lstm_param_.forget_variance_guard = 2;
    integer_lstm_param_.cell_variance_guard = 2;
    integer_lstm_param_.output_variance_guard = 1;
    integer_lstm_param_.hidden_zp = 0;
    integer_lstm_param_.input_to_forget_effective_bias.reset(
        new int32_t[n_cell_]);
    integer_lstm_param_.recurrent_to_forget_effective_bias.reset(
        new int32_t[n_cell_]);
    integer_lstm_param_.input_to_cell_effective_bias.reset(
        new int32_t[n_cell_]);
    integer_lstm_param_.recurrent_to_cell_effective_bias.reset(
        new int32_t[n_cell_]);
    integer_lstm_param_.input_to_output_effective_bias.reset(
        new int32_t[n_cell_]);
    integer_lstm_param_.recurrent_to_output_effective_bias.reset(
        new int32_t[n_cell_]);
    integer_lstm_param_.input_to_input_effective_bias.reset(
        new int32_t[n_cell_]);
    integer_lstm_param_.recurrent_to_input_effective_bias.reset(
        new int32_t[n_cell_]);
    integer_lstm_param_.projection_effective_bias.reset(new int32_t[n_output_]);
    std::fill_n(integer_lstm_param_.input_to_forget_effective_bias.get(),
                n_cell_, 152);
    std::fill_n(integer_lstm_param_.recurrent_to_forget_effective_bias.get(),
                n_cell_, 315);
    std::fill_n(integer_lstm_param_.input_to_cell_effective_bias.get(), n_cell_,
                165);
    std::fill_n(integer_lstm_param_.recurrent_to_cell_effective_bias.get(),
                n_cell_, 1165);
    std::fill_n(integer_lstm_param_.input_to_output_effective_bias.get(),
                n_cell_, 159);
    std::fill_n(integer_lstm_param_.recurrent_to_output_effective_bias.get(),
                n_cell_, 915);
    std::fill_n(integer_lstm_param_.input_to_input_effective_bias.get(),
                n_cell_, -15);
    std::fill_n(integer_lstm_param_.recurrent_to_input_effective_bias.get(),
                n_cell_, 315);
    std::fill_n(integer_lstm_param_.projection_effective_bias.get(), n_output_,
                115);
    return &integer_lstm_param_;
  }

  // Create scratch buffers.
  TfLiteTensor* GetScratch0() {
    PackWeightToTensor(&scratch0_tensor_, scratch0_, scratch0_size_);
    scratch0_tensor_.data.i16 = scratch0_.data();
    return &scratch0_tensor_;
  }
  TfLiteTensor* GetScratch1() {
    PackWeightToTensor(&scratch1_tensor_, scratch1_, scratch1_size_);
    scratch1_tensor_.data.i16 = scratch1_.data();
    return &scratch1_tensor_;
  }
  TfLiteTensor* GetScratch2() {
    PackWeightToTensor(&scratch2_tensor_, scratch2_, scratch2_size_);
    scratch2_tensor_.data.i16 = scratch2_.data();
    return &scratch2_tensor_;
  }
  TfLiteTensor* GetScratch3() {
    PackWeightToTensor(&scratch3_tensor_, scratch3_, scratch3_size_);
    scratch3_tensor_.data.i16 = scratch3_.data();
    return &scratch3_tensor_;
  }
  TfLiteTensor* GetScratch4() {
    PackWeightToTensor(&scratch4_tensor_, scratch4_, scratch4_size_);
    scratch4_tensor_.data.int8 = scratch4_.data();
    return &scratch4_tensor_;
  }
  TfLiteTensor* GetScratch5() {
    PackWeightToTensor(&scratch5_tensor_, scratch5_, scratch5_size_);
    scratch5_tensor_.data.i32 = scratch5_.data();
    return &scratch5_tensor_;
  }
  TfLiteTensor* GetActivation() {
    PackWeightToTensor(&activation_tensor_, activation_, activation_size_);
    activation_tensor_.data.int8 = activation_.data();
    activation_tensor_.params.zero_point = 50;
    return &activation_tensor_;
  }
  TfLiteTensor* GetOutput() {
    PackWeightToTensor(&output_tensor_, output_, output_size_);
    output_tensor_.data.int8 = output_.data();
    return &output_tensor_;
  }
  TfLiteTensor* GetCell() {
    PackWeightToTensor(&cell_tensor_, cell_, cell_size_);
    cell_tensor_.data.i16 = cell_.data();
    return &cell_tensor_;
  }
  ~QuantizedLstmParam() {
    TfLiteIntArrayFree(scratch0_tensor_.dims);
    TfLiteIntArrayFree(scratch1_tensor_.dims);
    TfLiteIntArrayFree(scratch2_tensor_.dims);
    TfLiteIntArrayFree(scratch3_tensor_.dims);
    TfLiteIntArrayFree(scratch4_tensor_.dims);
    TfLiteIntArrayFree(scratch5_tensor_.dims);
  }

 private:
  // input.
  std::vector<int8_t> input_ = {
      8, 2, 3,  4, 5, 6, 1, -2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,  //
      1, 2, -3, 4, 5, 6, 1, 2,  3, 4, 5, 6, 1, 2, 3, 4, 5, 6,  //
  };

  std::vector<int16_t> layer_norm_input_ = {8, 2, 3, 4, 5, 6, 1, 2, 3, 4};

  // forget_layer_norm_coefficient.
  std::vector<int16_t> layer_norm_forget_ = {
      1, 2, 3, 4, 7, 3, 4, -5, 6, 3,  //
  };

  // cell_layer_norm_coefficients.
  std::vector<int16_t> layer_norm_cell_ = {
      6, 4, 5, 6, 1, 2, 3, 4, -5, 6,  //
  };

  // output_layer_norm_coefficients.
  std::vector<int16_t> layer_norm_output_ = {
      16, 4, 5, 6, 1, 1, 3, 4, -5, 6,  //
  };

  // input_gate_bias.
  std::vector<int32_t> input_gate_bias_ = {
      16, 4, 5, 6, 1, 1, 3, 4, -5, 6,  //
  };

  // forget_gate_bias.
  std::vector<int32_t> forget_gate_bias_ = {
      16, 4, 5, 6, 1, 1, 3, 4, -5, 6,  //
  };

  // cell_gate_bias.
  std::vector<int32_t> cell_gate_bias_ = {
      16, 4, 5, 6, 1, 1, 3, 4, -5, 6,  //
  };

  // output_gate_bias.
  std::vector<int32_t> output_gate_bias_ = {
      16, 4, 5, 6, 1, 1, 3, 4, -5, 6,  //
  };

  // activation.
  std::vector<int8_t> activation_;

  // cell.
  std::vector<int16_t> cell_ = {
      16, 4,  5, 6, 1, 1, 3, 4, -5, 6,  //
      1,  14, 5, 6, 1, 1, 3, 4, -5, 6,  //
  };

  // output.
  std::vector<int8_t> output_ = {
      1, 1, 3, 4, -5, 6,  //
      1, 4, 3, 4, -5, 6,  //
  };

  // quantized_lstm_param.
  ops::builtin::lstm_eval::IntegerLstmParameter integer_lstm_param_;

  // 6 scratch buffers.
  std::vector<int16_t> scratch0_;
  std::vector<int32_t> scratch0_size_ = {n_batch_, n_cell_};
  TfLiteTensor scratch0_tensor_;
  std::vector<int16_t> scratch1_;
  std::vector<int32_t> scratch1_size_ = {n_batch_, n_cell_};
  TfLiteTensor scratch1_tensor_;
  std::vector<int16_t> scratch2_;
  std::vector<int32_t> scratch2_size_ = {n_batch_, n_cell_};
  TfLiteTensor scratch2_tensor_;
  std::vector<int16_t> scratch3_;
  std::vector<int32_t> scratch3_size_ = {n_batch_, n_cell_};
  TfLiteTensor scratch3_tensor_;
  std::vector<int8_t> scratch4_;
  std::vector<int32_t> scratch4_size_ = {n_batch_, n_cell_};
  TfLiteTensor scratch4_tensor_;
  std::vector<int32_t> scratch5_;
  std::vector<int32_t> scratch5_size_ = {n_batch_, n_cell_};
  TfLiteTensor scratch5_tensor_;
};

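// Runs a single invocation of the integer (8x8_16) LSTM kernel on fixed inputs
// and checks the resulting cell state and activation against golden values.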
void TestOneFullyQuantizedLSTM() {
  CpuBackendContext context;
  QuantizedLstmParam one_parameter;
  auto activation = one_parameter.GetActivation();
  auto output = one_parameter.GetOutput();
  auto cell = one_parameter.GetCell();
  auto param = one_parameter.GetQuantParam();
  ops::builtin::lstm_eval::EvalInteger8x8_16(
      one_parameter.GetInput(), one_parameter.Geti2i(), one_parameter.Geti2f(),
      one_parameter.Geti2c(), one_parameter.Geti2o(), one_parameter.Getr2i(),
      one_parameter.Getr2f(), one_parameter.Getr2c(), one_parameter.Getr2o(),
      nullptr, nullptr, nullptr, one_parameter.GetInputLayerNorm(),
      one_parameter.GetForgetLayerNorm(), one_parameter.GetCellLayerNorm(),
      one_parameter.GetOutputLayerNorm(), one_parameter.GetInputBias(),
      one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
      one_parameter.GetOutputBias(), one_parameter.GetProjection(),
      one_parameter.GetProjectionBias(), nullptr, /*forward_sequence=*/true,
      /*time_major=*/true, param, activation, cell, output,
      one_parameter.GetScratch0(), one_parameter.GetScratch1(),
      one_parameter.GetScratch2(), one_parameter.GetScratch3(),
      one_parameter.GetScratch4(), one_parameter.GetScratch5(), &context);

  // Verify results.
  const std::vector<int16_t> expected_cell = {
      7, 1, 3, 2, 0, 1, 0, 2, -2, 4, 1, 6, 4, 3, 0, 1, 0, 2, -2, 4,
  };
  const std::vector<int8_t> expected_activation = {
      50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
  };
  EXPECT_TRUE(ArrayEq(cell->data.i16, expected_cell.data(), 20));
  EXPECT_TRUE(ArrayEq(activation->data.int8, expected_activation.data(), 12));
  EXPECT_TRUE(ArrayEq(output->data.int8, expected_activation.data(), 12));
}

TEST(TestOneFullyQuantizedLSTM, TestOneFullyQuantizedLSTM) {
  TestOneFullyQuantizedLSTM();
}

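// Input parameters and scratch tensors for the hybrid LSTM test: activations
// stay in float while the int8 weights are applied via on-the-fly input
// quantization, which needs the extra scaling-factor and zero-point tensors
// below.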
class HybridLstmParam : public BaseLstmParam {
 public:
  TfLiteTensor* GetFloatOutput() {
    PackWeightToTensor(&output_tensor_, output_float_, output_size_);
    output_tensor_.data.f = output_float_.data();
    return &output_tensor_;
  }
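  // LSTM params used by the hybrid test: ReLU activation, no cell or
  // projection clipping, the full LSTM kernel, and asymmetric quantization of
  // inputs (hence the "Asymm" in the test name).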
  const TfLiteLSTMParams GetLSTMParam() {
    return {kTfLiteActRelu, 0, 0, kTfLiteLSTMFullKernel, true};
  }
  TfLiteTensor* GetScratchBuffer() {
    PackWeightToTensor(&scratch_buffer_tensor_, scratch_buffer_,
                       scratch_buffer_size_);
    scratch_buffer_tensor_.data.f = scratch_buffer_.data();
    return &scratch_buffer_tensor_;
  }
  TfLiteTensor* GetInputScalingFactors() {
    PackWeightToTensor(&input_sf_tensor_, input_sf_,
                       quantization_extra_scratch_buffer_sizes_);
    input_sf_tensor_.data.f = input_sf_.data();
    return &input_sf_tensor_;
  }
  TfLiteTensor* GetAuxInputScalingFactors() {
    PackWeightToTensor(&aux_input_sf_tensor_, aux_input_sf_,
                       quantization_extra_scratch_buffer_sizes_);
    aux_input_sf_tensor_.data.f = aux_input_sf_.data();
    return &aux_input_sf_tensor_;
  }
  TfLiteTensor* GetOutputStateScalingFactors() {
    PackWeightToTensor(&output_state_sf_tensor_, output_state_sf_,
                       quantization_extra_scratch_buffer_sizes_);
    output_state_sf_tensor_.data.f = output_state_sf_.data();
    return &output_state_sf_tensor_;
  }
  TfLiteTensor* GetProdScalingFactors() {
    PackWeightToTensor(&prod_scaling_factors_tensor_, prod_scaling_factors_,
                       quantization_extra_scratch_buffer_sizes_);
    prod_scaling_factors_tensor_.data.f = prod_scaling_factors_.data();
    return &prod_scaling_factors_tensor_;
  }
  TfLiteTensor* GetInputQuantized() {
    PackWeightToTensor(&input_quantized_tensor_, input_quantized_, input_size_);
    input_quantized_tensor_.data.int8 = input_quantized_.data();
    return &input_quantized_tensor_;
  }
  TfLiteTensor* GetActivationStateQuantized() {
    PackWeightToTensor(&activation_quantized_tensor_, activation_quantized_,
                       activation_size_);
    activation_quantized_tensor_.data.int8 = activation_quantized_.data();
    return &activation_quantized_tensor_;
  }
  TfLiteTensor* GetCellStateQuantized() {
    PackWeightToTensor(&cell_quantized_tensor_, cell_quantized_, cell_size_);
    cell_quantized_tensor_.data.int8 = cell_quantized_.data();
    return &cell_quantized_tensor_;
  }
  TfLiteTensor* GetInputZeroPoints() {
    PackWeightToTensor(&input_zp_tensor_, input_zp_,
                       quantization_extra_scratch_buffer_sizes_);
    input_zp_tensor_.data.i32 = input_zp_.data();
    return &input_zp_tensor_;
  }
  TfLiteTensor* GetAuxInputZeroPoints() {
    PackWeightToTensor(&aux_input_zp_tensor_, aux_input_zp_,
                       quantization_extra_scratch_buffer_sizes_);
    aux_input_zp_tensor_.data.i32 = aux_input_zp_.data();
    return &aux_input_zp_tensor_;
  }
  TfLiteTensor* GetOutputStateZeroPoints() {
    PackWeightToTensor(&output_state_zp_tensor_, output_state_zp_,
                       quantization_extra_scratch_buffer_sizes_);
    output_state_zp_tensor_.data.i32 = output_state_zp_.data();
    return &output_state_zp_tensor_;
  }
  TfLiteTensor* GetRowSums() {
    PackWeightToTensor(&row_sums_tensor_, row_sums_, row_sums_size_);
    row_sums_tensor_.data.i32 = row_sums_.data();
    return &row_sums_tensor_;
  }
  TfLiteTensor* GetFloatInput() {
    PackWeightToTensor(&input_tensor_, input_float_, input_size_);
    input_tensor_.data.f = input_float_.data();
    return &input_tensor_;
  }
  TfLiteTensor* GetActivation() {
    PackWeightToTensor(&activation_tensor_, activation_state_,
                       activation_size_);
    activation_tensor_.data.f = activation_state_.data();
    return &activation_tensor_;
  }
  TfLiteTensor* GetCell() {
    PackWeightToTensor(&cell_tensor_, cell_state_, cell_size_);
    cell_tensor_.data.f = cell_state_.data();
    return &cell_tensor_;
  }
  TfLiteTensor* GetAccumScratchBuffer() {
    PackWeightToTensor(&accum_scratch_tensor_, accum_scratch_,
                       accum_scratch_size_);
    accum_scratch_tensor_.data.i32 = accum_scratch_.data();
    return &accum_scratch_tensor_;
  }
  TfLiteTensor* GetInputBias() {
    PackWeightToTensor(&input_gate_bias_tensor_, input_float_bias_,
                       input_gate_bias_size_);
    input_gate_bias_tensor_.data.f = input_float_bias_.data();
    return &input_gate_bias_tensor_;
  }
  TfLiteTensor* GetForgetBias() {
    PackWeightToTensor(&forget_gate_bias_tensor_, forget_float_bias_,
                       forget_gate_bias_size_);
    forget_gate_bias_tensor_.data.f = forget_float_bias_.data();
    return &forget_gate_bias_tensor_;
  }
  TfLiteTensor* GetCellBias() {
    PackWeightToTensor(&cell_gate_bias_tensor_, cell_float_bias_,
                       cell_gate_bias_size_);
    cell_gate_bias_tensor_.data.f = cell_float_bias_.data();
    return &cell_gate_bias_tensor_;
  }
  TfLiteTensor* GetOutputBias() {
    PackWeightToTensor(&output_gate_bias_tensor_, output_float_bias_,
                       output_gate_bias_size_);
    output_gate_bias_tensor_.data.f = output_float_bias_.data();
    return &output_gate_bias_tensor_;
  }
  TfLiteTensor* GetProjectionBias() {
    PackWeightToTensor(&projection_bias_tensor_, projection_float_bias_,
                       projection_bias_size_);
    projection_bias_tensor_.data.f = projection_float_bias_.data();
    return &projection_bias_tensor_;
  }
  int GetNumRowSums() { return n_row_sums_; }
  TfLiteTensor* GetInputLayerNorm() {
    PackWeightToTensor(&layer_norm_input_tensor_, layer_norm_float_input_,
                       layer_norm_input_size_);
    layer_norm_input_tensor_.data.f = layer_norm_float_input_.data();
    return &layer_norm_input_tensor_;
  }
  TfLiteTensor* GetForgetLayerNorm() {
    PackWeightToTensor(&layer_norm_forget_tensor_, layer_norm_float_forget_,
                       layer_norm_forget_size_);
    layer_norm_forget_tensor_.data.f = layer_norm_float_forget_.data();
    return &layer_norm_forget_tensor_;
  }
  TfLiteTensor* GetCellLayerNorm() {
    PackWeightToTensor(&layer_norm_cell_tensor_, layer_norm_float_cell_,
                       layer_norm_cell_size_);
    layer_norm_cell_tensor_.data.f = layer_norm_float_cell_.data();
    return &layer_norm_cell_tensor_;
  }
  TfLiteTensor* GetOutputLayerNorm() {
    PackWeightToTensor(&layer_norm_output_tensor_, layer_norm_float_output_,
                       layer_norm_output_size_);
    layer_norm_output_tensor_.data.f = layer_norm_float_output_.data();
    return &layer_norm_output_tensor_;
  }
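  // Sets the quantization scale on `t` and returns it, so that weight tensors
  // can be scaled inline while building the EvalHybrid argument list.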
  static TfLiteTensor* addScale(TfLiteTensor* t, float scale) {
    t->params.scale = scale;
    return t;
  }
  ~HybridLstmParam() {
    TfLiteIntArrayFree(scratch_buffer_tensor_.dims);
    TfLiteIntArrayFree(accum_scratch_tensor_.dims);
    TfLiteIntArrayFree(input_sf_tensor_.dims);
    TfLiteIntArrayFree(aux_input_sf_tensor_.dims);
    TfLiteIntArrayFree(output_state_sf_tensor_.dims);
    TfLiteIntArrayFree(prod_scaling_factors_tensor_.dims);
    TfLiteIntArrayFree(input_quantized_tensor_.dims);
    TfLiteIntArrayFree(activation_quantized_tensor_.dims);
    TfLiteIntArrayFree(cell_quantized_tensor_.dims);
    TfLiteIntArrayFree(input_zp_tensor_.dims);
    TfLiteIntArrayFree(aux_input_zp_tensor_.dims);
    TfLiteIntArrayFree(output_state_zp_tensor_.dims);
    TfLiteIntArrayFree(row_sums_tensor_.dims);
  }

 private:
  const int n_row_sums_ = 9;  // Number of weights + 1 for projection weights.

  std::vector<float> scratch_buffer_;
  std::vector<int32_t> scratch_buffer_size_ = {n_batch_, n_cell_ * 4};
  TfLiteTensor scratch_buffer_tensor_;

  std::vector<int32_t> quantization_extra_scratch_buffer_sizes_ = {n_batch_};
  std::vector<float> input_sf_;
  TfLiteTensor input_sf_tensor_;
  std::vector<float> aux_input_sf_;
  TfLiteTensor aux_input_sf_tensor_;
  std::vector<float> output_state_sf_;
  TfLiteTensor output_state_sf_tensor_;

  std::vector<float> prod_scaling_factors_;
  TfLiteTensor prod_scaling_factors_tensor_;

  std::vector<int32_t> input_zp_;
  TfLiteTensor input_zp_tensor_;
  std::vector<int32_t> aux_input_zp_;
  TfLiteTensor aux_input_zp_tensor_;
  std::vector<int32_t> output_state_zp_;
  TfLiteTensor output_state_zp_tensor_;

  std::vector<int8_t> input_quantized_;
  TfLiteTensor input_quantized_tensor_;

  std::vector<int8_t> activation_quantized_;
  TfLiteTensor activation_quantized_tensor_;

  std::vector<int8_t> cell_quantized_;
  TfLiteTensor cell_quantized_tensor_;

  std::vector<float> cell_state_ = {
      16, 4, 5, 6, 1, 1, 3, 4, -5, 6, 1, 14, 5, 6, 1, 1, 3, 4, -5, 6,
  };

  std::vector<int32_t> row_sums_;
  std::vector<int32_t> row_sums_size_ = {n_row_sums_, n_cell_};
  TfLiteTensor row_sums_tensor_;

  std::vector<float> activation_state_;

  std::vector<int32_t> accum_scratch_;
  std::vector<int32_t> accum_scratch_size_ = {n_cell_, n_batch_};
  TfLiteTensor accum_scratch_tensor_;
  std::vector<float> output_float_ = {
      1, 1, 3, 4, -5, 6,  //
      1, 4, 3, 4, -5, 6,  //
  };
  std::vector<float> input_float_ = {
      6.06, 7.66, 7.10, 9.32, 3.85, 0.33, 7.15, 1.56, 9.54,
      5.30, 4.53, 0.19, 1.83, 4.60, 0.84, 5.08, 4.37, 9.92,  //
      4.08, 3.79, 1.17, 8.99, 0.14, 9.22, 3.18, 2.97, 7.53,
      0.59, 9.89, 9.13, 7.68, 0.63, 2.15, 4.31, 7.20, 4.09,  //
  };
  std::vector<float> input_float_bias_ = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  std::vector<float> forget_float_bias_ = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  std::vector<float> cell_float_bias_ = {
      -11, -7, -4, -5, -1, -1, -2, -3.5, -3, -4,
  };
  std::vector<float> output_float_bias_ = {0.16, 0.4, 0.5, 0.6,  0.1,
                                           0.1,  0.3, 0.4, -0.5, 0.6};
  std::vector<float> projection_float_bias_ = {0, 0, 0, 0, 0, 0};
  std::vector<float> layer_norm_float_input_ = {8, 2, 3, 4, 5, 6, 1, -2, 3, 4};
  std::vector<float> layer_norm_float_forget_ = {
      0.1, 0.2, 0.3, 0.4, 0.7, 0.3, 0.4, -0.5, 0.6, 0.3,  //
  };
  std::vector<float> layer_norm_float_cell_ = {
      0.6, 0.4, 0.5, 0.6, 0.1, 0.2, 0.3, 0.4, -0.5, 0.6,  //
  };
  std::vector<float> layer_norm_float_output_ = {
      0.6, 0.4, 0.5, 0.6, 0.1, 0.2, 0.3, 0.4, -0.5, 0.6,  //
  };
};

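// Runs a single invocation of the hybrid LSTM kernel (float tensors with int8
// weights and asymmetric on-the-fly input quantization) and checks the
// resulting cell state, activation, and output against golden values.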
void TestOneHybridAsymmLSTM() {
  CpuBackendContext context;
  HybridLstmParam one_parameter;
  auto activation = one_parameter.GetActivation();
  auto output = one_parameter.GetFloatOutput();
  auto cell = one_parameter.GetCell();
  auto param = one_parameter.GetLSTMParam();
  bool compute_row_sums = true;
  constexpr float kDefaultScale = 18.0;
  ops::builtin::lstm_eval::EvalHybrid(
      one_parameter.GetFloatInput(),
      HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale), nullptr,
      HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale), nullptr,
      HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale), nullptr,
      HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale), nullptr,
      HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale), nullptr,
      HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale), nullptr,
      HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale), nullptr,
      HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale), nullptr,
      /*cell_to_input_weights=*/nullptr,
      /*cell_to_forget_weights=*/nullptr,
      /*cell_to_output_weights=*/nullptr, one_parameter.GetInputLayerNorm(),
      one_parameter.GetForgetLayerNorm(), one_parameter.GetCellLayerNorm(),
      one_parameter.GetOutputLayerNorm(),
      /*aux_input=*/nullptr,
      /*aux_input_to_input_weights=*/nullptr,
      /*aux_input_to_forget_weights=*/nullptr,
      /*aux_input_to_cell_weights=*/nullptr,
      /*aux_input_to_output_weights=*/nullptr, one_parameter.GetInputBias(),
      one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
      one_parameter.GetOutputBias(),
      HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0), nullptr,
      one_parameter.GetProjectionBias(), &param,
      /*forward_sequence=*/true,
      /*time_major=*/true,
      /*output_offset=*/0, one_parameter.GetScratchBuffer(),
      one_parameter.GetInputScalingFactors(),
      one_parameter.GetAuxInputScalingFactors(),
      one_parameter.GetOutputStateScalingFactors(),
      one_parameter.GetProdScalingFactors(),
      /*recovered_cell_weights=*/nullptr, one_parameter.GetInputQuantized(),
      /*aux_input_quantized=*/nullptr,
      one_parameter.GetActivationStateQuantized(),
      one_parameter.GetCellStateQuantized(), activation, cell,
      one_parameter.GetAccumScratchBuffer(), output,
      one_parameter.GetInputZeroPoints(), one_parameter.GetAuxInputZeroPoints(),
      one_parameter.GetOutputStateZeroPoints(), one_parameter.GetRowSums(),
      one_parameter.GetNumRowSums(), &compute_row_sums, &context);
  const std::vector<float> expected_cell = {
      7.83134,  1.96158, 2.18285, 3.28739,  0.483214,
      0.618206, 1.21539, 1.4052,  -3.17735, 2.24296,  //
      0.498944, 6.91104, 1.74126, 3.28993,  0.580477,
      0.489936, 1.2527,  1.50157, -3.71849, 2.76743,  //
  };
  const std::vector<float> expected_activation = {
      53.0403, 59.3623, 24.8493, 53.0403, 59.3623, 24.8493,  //
      36.7559, 57.5202, 29.7217, 36.7559, 57.5202, 29.7217,
  };
  EXPECT_TRUE(ArrayFloatNear(cell->data.f, expected_cell.data(), 20, 1e-2));
  EXPECT_TRUE(
      ArrayFloatNear(activation->data.f, expected_activation.data(), 12, 1e-4));
  EXPECT_TRUE(
      ArrayFloatNear(output->data.f, expected_activation.data(), 12, 1e-4));
}

TEST(TestOneHybridAsymmLSTM, TestOneHybridAsymmLSTM) {
  TestOneHybridAsymmLSTM();
}

}  // namespace
}  // namespace tflite