/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>

#include <cstddef>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/lstm_shared.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {
namespace {

struct OpData {
  // Whether the LSTM uses layer normalization.
  bool use_layer_norm;
  // The scratch tensor index.
  int scratch_tensor_index;
  bool compute_row_sums = false;

  lstm_eval::IntegerLstmParameter integer_lstm_param;
};

TfLiteStatus PopulateQuantizedLstmParams8x8_16(
    TfLiteContext* context, TfLiteNode* node,
    lstm_eval::IntegerLstmParameter* integer_lstm_param) {
  // Calculate quantized clip for projection and cell.
  const auto* params =
      static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data);
  const float cell_clip = params->cell_clip;
  const float proj_clip = params->proj_clip;

  const TfLiteTensor* cell_state =
      GetVariableInput(context, node, lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);
  TfLiteTensor* output_tensor;
  TF_LITE_ENSURE_OK(
      context,
      GetOutputSafe(context, node, lstm::full::kOutputTensor, &output_tensor));

  auto* cell_state_params =
      static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
  auto* proj_params = static_cast<TfLiteAffineQuantization*>(
      output_tensor->quantization.params);
  if (cell_clip > 0.0) {
    integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
        std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
        32767.0f));
  } else {
    integer_lstm_param->quantized_cell_clip = 0;
  }
  if (proj_clip > 0.0) {
    integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
        std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
  } else {
    integer_lstm_param->quantized_proj_clip = 0;
  }
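  // The quantized clip is the float clip expressed in the state tensor's
  // quantized units, clamped to the integer range. As a worked example
  // (values assumed for illustration): with cell_clip = 8.0 and a cell-state
  // scale of 2^-11, the quantized clip is 8.0 / 2^-11 = 16384, which fits in
  // int16; a larger ratio would saturate at 32767.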

  // Calculate effective scales.
  OpData* op_data = static_cast<OpData*>(node->user_data);
  const bool use_layer_norm = op_data->use_layer_norm;

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);

  TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);

  // Since we have already checked that the weights are either all present or
  // all absent, checking a single one is sufficient.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);
  const bool use_projection = (projection_weights != nullptr);

  // Get intermediate scales and zero points.
  std::vector<float> intermediate_scale;
  std::vector<int32> intermediate_zp;
  for (int i = 0; i < 4; ++i) {
    if (use_layer_norm) {
      TfLiteTensor* intermediate;
      TF_LITE_ENSURE_OK(context,
                        GetIntermediatesSafe(context, node, i, &intermediate));
      auto* params = static_cast<TfLiteAffineQuantization*>(
          intermediate->quantization.params);
      intermediate_scale.push_back(params->scale->data[0]);
      intermediate_zp.push_back(params->zero_point->data[0]);
    } else {
      // Q3.12 for activation functions: a fixed scale of 2^-12, i.e. int16
      // values covering [-8, 8) with 12 fractional bits.
      intermediate_scale.push_back(std::pow(2, -12));
      intermediate_zp.push_back(0);
    }
  }
  // In the absence of projection, hidden becomes output and this intermediate
  // is ignored.
  TfLiteTensor* hidden;
  TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
  auto* hidden_params =
      static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
  intermediate_scale.push_back(hidden_params->scale->data[0]);
  intermediate_zp.push_back(hidden_params->zero_point->data[0]);

  // Scales.
  const float default_scale = 1.0;
  float input_scale = default_scale;
  float input_to_input_weight_scale = default_scale;
  float recurrent_to_input_weight_scale = default_scale;
  float cell_to_input_weight_scale = default_scale;
  float input_to_forget_weight_scale = default_scale;
  float recurrent_to_forget_weight_scale = default_scale;
  float cell_to_forget_weight_scale = default_scale;
  float input_to_cell_weight_scale = default_scale;
  float recurrent_to_cell_weight_scale = default_scale;
  float input_to_output_weight_scale = default_scale;
  float recurrent_to_output_weight_scale = default_scale;
  float cell_to_output_weight_scale = default_scale;
  float projection_weight_scale = default_scale;
  float layer_norm_input_scale = default_scale;
  float layer_norm_forget_scale = default_scale;
  float layer_norm_cell_scale = default_scale;
  float layer_norm_output_scale = default_scale;
  float output_state_scale = default_scale;
  int cell_scale = 1;

  // Effective scales.
  float effective_input_to_input_scale = default_scale;
  float effective_recurrent_to_input_scale = default_scale;
  float effective_cell_to_input_scale = default_scale;
  float effective_input_to_forget_scale = default_scale;
  float effective_recurrent_to_forget_scale = default_scale;
  float effective_cell_to_forget_scale = default_scale;
  float effective_input_to_cell_scale = default_scale;
  float effective_recurrent_to_cell_scale = default_scale;
  float effective_input_to_output_scale = default_scale;
  float effective_recurrent_to_output_scale = default_scale;
  float effective_cell_to_output_scale = default_scale;
  float effective_proj_scale = default_scale;
  float effective_hidden_scale = default_scale;

  // Populate scales.
  if (!use_cifg) {
    input_to_input_weight_scale = input_to_input_weights->params.scale;
    recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
  }

  if (use_peephole) {
    if (!use_cifg) {
      cell_to_input_weight_scale = cell_to_input_weights->params.scale;
    }
    cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
    cell_to_output_weight_scale = cell_to_output_weights->params.scale;
  }

  if (use_layer_norm) {
    if (!use_cifg) {
      layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
    }
    layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
    layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
    layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
  }

  if (use_projection) {
    projection_weight_scale = projection_weights->params.scale;
  }
  output_state_scale = output_state->params.scale;

  input_to_forget_weight_scale = input_to_forget_weights->params.scale;
  input_to_cell_weight_scale = input_to_cell_weights->params.scale;
  input_to_output_weight_scale = input_to_output_weights->params.scale;
  recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
  recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
  recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;

  // Check cell state (already used above)
  TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
  // TF_LITE_ENSURE(context, cell_scale <= -9);
  integer_lstm_param->cell_scale = cell_scale;
  input_scale = input->params.scale;

  // Calculate effective scales.
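  // Each matmul accumulates products of quantized values, so the accumulator
  // carries a real scale of weight_scale * activation_scale. Requantizing
  // that accumulator into a gate's intermediate representation therefore uses
  // the multiplier weight_scale * activation_scale / intermediate_scale,
  // which is what the expressions below compute.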
  if (!use_cifg) {
    effective_input_to_input_scale =
        input_to_input_weight_scale * input_scale / intermediate_scale[0];
    effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
                                         output_state_scale /
                                         intermediate_scale[0];
  }
  effective_input_to_forget_scale =
      input_to_forget_weight_scale * input_scale / intermediate_scale[1];
  effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[1];

  effective_input_to_cell_scale =
      input_to_cell_weight_scale * input_scale / intermediate_scale[2];
  effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
                                      output_state_scale /
                                      intermediate_scale[2];

  effective_input_to_output_scale =
      input_to_output_weight_scale * input_scale / intermediate_scale[3];
  effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[3];

  effective_hidden_scale =
      std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);
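  // The two 2^-15 factors reflect that the hidden value is the product of two
  // int16 Q0.15 quantities (the output-gate sigmoid and the cell tanh), each
  // carrying a scale of 2^-15, requantized to intermediate_scale[4].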

  effective_proj_scale =
      projection_weight_scale * intermediate_scale[4] / output_state_scale;

  if (use_peephole) {
    if (!use_cifg) {
      effective_cell_to_input_scale = std::pow(2, cell_scale) *  // NOLINT
                                      cell_to_input_weight_scale /
                                      intermediate_scale[0];
    }
    effective_cell_to_forget_scale = std::pow(2, cell_scale) *  // NOLINT
                                     cell_to_forget_weight_scale /
                                     intermediate_scale[1];
    effective_cell_to_output_scale = std::pow(2, cell_scale) *  // NOLINT
                                     cell_to_output_weight_scale /
                                     intermediate_scale[3];
  }

  // Decompose scales.
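  // QuantizeMultiplier splits each real-valued multiplier into an integer
  // significand `a` and a shift `b` so that, approximately,
  //   multiplier = a * 2^(b - 31),
  // letting the kernel apply the scale with a fixed-point multiply and shift
  // instead of floating-point math.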
  QuantizeMultiplier(effective_input_to_input_scale,
                     &integer_lstm_param->effective_input_to_input_scale_a,
                     &integer_lstm_param->effective_input_to_input_scale_b);
  QuantizeMultiplier(effective_recurrent_to_input_scale,
                     &integer_lstm_param->effective_recurrent_to_input_scale_a,
                     &integer_lstm_param->effective_recurrent_to_input_scale_b);
  QuantizeMultiplier(effective_cell_to_input_scale,
                     &integer_lstm_param->effective_cell_to_input_scale_a,
                     &integer_lstm_param->effective_cell_to_input_scale_b);
  QuantizeMultiplier(effective_input_to_forget_scale,
                     &integer_lstm_param->effective_input_to_forget_scale_a,
                     &integer_lstm_param->effective_input_to_forget_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_forget_scale,
      &integer_lstm_param->effective_recurrent_to_forget_scale_a,
      &integer_lstm_param->effective_recurrent_to_forget_scale_b);
  QuantizeMultiplier(effective_cell_to_forget_scale,
                     &integer_lstm_param->effective_cell_to_forget_scale_a,
                     &integer_lstm_param->effective_cell_to_forget_scale_b);
  QuantizeMultiplier(effective_input_to_cell_scale,
                     &integer_lstm_param->effective_input_to_cell_scale_a,
                     &integer_lstm_param->effective_input_to_cell_scale_b);
  QuantizeMultiplier(effective_recurrent_to_cell_scale,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_a,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_b);
  QuantizeMultiplier(effective_input_to_output_scale,
                     &integer_lstm_param->effective_input_to_output_scale_a,
                     &integer_lstm_param->effective_input_to_output_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_output_scale,
      &integer_lstm_param->effective_recurrent_to_output_scale_a,
      &integer_lstm_param->effective_recurrent_to_output_scale_b);
  QuantizeMultiplier(effective_cell_to_output_scale,
                     &integer_lstm_param->effective_cell_to_output_scale_a,
                     &integer_lstm_param->effective_cell_to_output_scale_b);
  QuantizeMultiplier(effective_proj_scale,
                     &integer_lstm_param->effective_proj_scale_a,
                     &integer_lstm_param->effective_proj_scale_b);
  QuantizeMultiplier(effective_hidden_scale,
                     &integer_lstm_param->effective_hidden_scale_a,
                     &integer_lstm_param->effective_hidden_scale_b);
  QuantizeMultiplier(layer_norm_input_scale,
                     &integer_lstm_param->layer_norm_input_scale_a,
                     &integer_lstm_param->layer_norm_input_scale_b);
  QuantizeMultiplier(layer_norm_forget_scale,
                     &integer_lstm_param->layer_norm_forget_scale_a,
                     &integer_lstm_param->layer_norm_forget_scale_b);
  QuantizeMultiplier(layer_norm_cell_scale,
                     &integer_lstm_param->layer_norm_cell_scale_a,
                     &integer_lstm_param->layer_norm_cell_scale_b);
  QuantizeMultiplier(layer_norm_output_scale,
                     &integer_lstm_param->layer_norm_output_scale_a,
                     &integer_lstm_param->layer_norm_output_scale_b);

  integer_lstm_param->hidden_zp = intermediate_zp[4];

  // 10000 is used to make sure the kernel logic does not overflow.
  if (!use_cifg) {
    integer_lstm_param->input_variance_guard =
        std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale));
  }
  integer_lstm_param->forget_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale));
  integer_lstm_param->cell_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale));
  integer_lstm_param->output_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale));
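  // For instance, a layer-norm scale of 1e-3 yields a variance guard of 10,
  // while any scale at or below 1e-4 clamps to the minimum guard of 1.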

  return kTfLiteOk;
}

}  // namespace

// Temporary tensors
enum TemporaryTensor {
  kScratchBuffer = 0,
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
  kInputScalingFactors = 4,
  kOutputStateScalingFactors = 5,
  kProductScalingFactors = 6,
  kRecoveredCellWeights = 7,
  kAccumScratch = 8,
  kInputZeroPoints = 9,
  kOutputStateZeroPoints = 10,
  kRowSums = 11,
  kNumTemporaryTensors = 12,
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell,
                                        bool use_layer_norm, bool is_integer) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Make sure the clipping parameters have valid values:
  //   == 0 means no clipping,
  //    > 0 means clipping is enabled.
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // Make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToInputWeightsTensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_input_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToForgetWeightsTensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_forget_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToOutputWeightsTensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_output_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
    }
  }

  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
                            &forget_gate_bias));
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
                                 &cell_gate_bias));
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
                            &output_gate_bias));
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
    }
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weights are not present, then the projection bias
  //    should not be present either.
  // 2) If the projection weights are present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  if (use_layer_norm) {
    const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
        context, node, lstm::full::kInputLayerNormCoefficientsTensor);
    if (use_cifg) {
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
    } else {
      TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
                        n_cell);
      if (is_integer) {
        TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
                                kTfLiteInt16);
      } else {
        TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
                                kTfLiteFloat32);
      }
    }

    const TfLiteTensor* forget_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(
        context, GetInputSafe(context, node,
                              lstm::full::kForgetLayerNormCoefficientsTensor,
                              &forget_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }

    const TfLiteTensor* cell_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(context,
                      GetInputSafe(context, node,
                                   lstm::full::kCellLayerNormCoefficientsTensor,
                                   &cell_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }

    const TfLiteTensor* output_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(
        context, GetInputSafe(context, node,
                              lstm::full::kOutputLayerNormCoefficientsTensor,
                              &output_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }
  }

  return kTfLiteOk;
}

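// Precomputes the zero-point folding term for one weight matrix. For
// asymmetric int8 quantization, sum_j (x_j - zp) * w_ij expands to
// sum_j x_j * w_ij - zp * sum_j w_ij, so the zp * row-sum term (plus any
// bias) can be computed once here and reused at every timestep. Callers pass
// the negated zero point, making the result bias_i - zp * sum_j w_ij.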
TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
    TfLiteContext* context, int32_t zero_point,
    const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
    std::unique_ptr<int32_t[]>* output) {
  if (weight_tensor == nullptr) {
    return kTfLiteOk;
  }

  const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
  TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
  const int row = weight_shape.Dims(0);
  const int col = weight_shape.Dims(1);
  output->reset(new int32_t[row]);
  if (bias_tensor == nullptr) {
    memset(output->get(), 0, row * sizeof(int32_t));
  } else {
    const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
    memcpy(output->get(), bias, row * sizeof(int32_t));
  }
  if (zero_point != 0) {
    const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
    tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col,
                                                 output->get());
  }
  return kTfLiteOk;
}

TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
                                                       OpData* op_data,
                                                       TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
  const TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);

  const int32_t input_zero_point = -input->params.zero_point;
  const int32_t output_state_zero_point = -output_state->params.zero_point;
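  // The zero points are negated here so that the precomputed terms below come
  // out as bias - zero_point * row_sum(weights), the correction the integer
  // kernel needs for asymmetric quantization.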

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);

  lstm_eval::IntegerLstmParameter* integer_lstm_params =
      &op_data->integer_lstm_param;

  const TfLiteTensor* intermediate =
      &context->tensors[node->intermediates->data[4]];
  const auto* params =
      static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
  const int32_t hidden_zp = params->zero_point->data[0];

  // Get the bias and perform the zero point calculation.
  // When layer normalization is used, the gate bias is not part of the matmul
  // directly:
  //      y = ln(w * x + w * r + w * c) + b.
  const bool is_layer_norm = op_data->use_layer_norm;

  // Forget gate.
  const TfLiteTensor* forget_gate_bias =
      is_layer_norm
          ? nullptr
          : GetInput(context, node, lstm::full::kForgetGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_forget_weights, forget_gate_bias,
          &(integer_lstm_params->input_to_forget_effective_bias)));

  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_forget_weights,
          nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));

  // Modulation gate.
  const TfLiteTensor* cell_gate_bias =
      is_layer_norm ? nullptr
                    : GetInput(context, node, lstm::full::kCellGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_cell_weights, cell_gate_bias,
          &(integer_lstm_params->input_to_cell_effective_bias)));
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
          &(integer_lstm_params->recurrent_to_cell_effective_bias)));

  // Output gate.
  const TfLiteTensor* output_gate_bias =
      is_layer_norm
          ? nullptr
          : GetInput(context, node, lstm::full::kOutputGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_output_weights, output_gate_bias,
          &(integer_lstm_params->input_to_output_effective_bias)));

  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_output_weights,
          nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));

  // Input gate. The calculation is only meaningful in the non-CIFG case.
  const TfLiteTensor* input_gate_bias =
      is_layer_norm ? nullptr
                    : GetInput(context, node, lstm::full::kInputGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_input_weights, input_gate_bias,
          &(integer_lstm_params->input_to_input_effective_bias)));
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_input_weights, nullptr,
          &(integer_lstm_params->recurrent_to_input_effective_bias)));

  // Projection bias. The calculation is only meaningful when projection is
  // used.
  TF_LITE_ENSURE_OK(context,
                    PrecomputeZeroPointTimesWeightWithBias(
                        context, hidden_zp, projection_weights, projection_bias,
                        &(integer_lstm_params->projection_effective_bias)));
  return kTfLiteOk;
}

// Resize the output and state tensors based on the sizes of the input
// tensors. Allocate a temporary scratch tensor. Also check that the sizes of
// the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const int scratch_tensor_index = op_data->scratch_tensor_index;

  // Check we have all the inputs and outputs we need.
  bool use_layer_norm = false;
  if (node->inputs->size == 24) {
    const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
        context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
    if (forget_layer_norm_coefficients == nullptr) {
      use_layer_norm = false;
    } else {
      use_layer_norm = true;
    }
  } else if (node->inputs->size == 20) {
    // This is deprecated and is only kept here for backward compatibility.
    use_layer_norm = false;
  } else {
    context->ReportError(
        context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
        node->inputs->size);
    return kTfLiteError;
  }
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  op_data->use_layer_norm = use_layer_norm;

  // Infer the batch size, number of outputs, sequence length, and number of
  // cells from the input tensors.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
  const bool is_integer = input->type == kTfLiteInt8;
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_output,
                                          n_cell, use_layer_norm, is_integer));

  // Get the pointer to output, output_state and cell_state buffer tensors.
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
                                           lstm::full::kOutputTensor, &output));

  TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; either is fine as long as the total size
  // is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);

  // Resize the output tensors.
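  // The output keeps the input's leading (time/batch) dimensions and only
  // replaces the innermost dimension with n_output.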
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  output_size->data[input->dims->size - 1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  if (is_integer) {
    const int num_intermediate_tensors = node->intermediates->size;
    TF_LITE_ENSURE(context, num_intermediate_tensors == 5);
  }

  TfLiteIntArrayFree(node->temporaries);
  if (IsHybridOp(input, input_to_output_weights)) {
    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
  } else if (is_integer) {
    node->temporaries = TfLiteIntArrayCreate(6);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[kScratchBuffer] =
      scratch_tensor_index + kScratchBuffer;

  // Create a scratch buffer tensor.
  TfLiteTensor* scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
                                              &scratch_buffer));
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  if (use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));

  if (IsHybridOp(input, input_to_output_weights)) {
    op_data->compute_row_sums = true;
    // Allocate temporary tensors to store quantized values of input,
    // output_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[kOutputStateQuantized] =
        scratch_tensor_index + kOutputStateQuantized;
    TfLiteTensor* output_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateQuantized,
                                       &output_state_quantized));
    output_state_quantized->type = input_to_output_weights->type;
    output_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(output_state_quantized->dims,
                             output_state->dims)) {
      TfLiteIntArray* output_state_quantized_size =
          TfLiteIntArrayCopy(output_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output_state_quantized,
                                              output_state_quantized_size));
    }
    node->temporaries->data[kCellStateQuantized] =
        scratch_tensor_index + kCellStateQuantized;
    TfLiteTensor* cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kCellStateQuantized,
                                       &cell_state_quantized));
    cell_state_quantized->type = input_to_output_weights->type;
    cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
      TfLiteIntArray* cell_state_quantized_size =
          TfLiteIntArrayCopy(cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, cell_state_quantized,
                                              cell_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors and product scaling
    // factors. The latter is convenience storage that makes it possible to
    // quantize a vector once (producing the scaling factors) and multiply it
    // by different matrices (which requires multiplying the scaling factors
    // by the scaling factor of the matrix).
    node->temporaries->data[kInputScalingFactors] =
        op_data->scratch_tensor_index + kInputScalingFactors;
    TfLiteTensor* input_sf;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
    input_sf->type = kTfLiteFloat32;
    input_sf->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
      input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_sf, input_sf_size));
    }
    node->temporaries->data[kOutputStateScalingFactors] =
        op_data->scratch_tensor_index + kOutputStateScalingFactors;
    TfLiteTensor* output_state_sf;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
                                  &output_state_sf));
    output_state_sf->type = kTfLiteFloat32;
    output_state_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
      output_state_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
                                                       output_state_sf_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kProductScalingFactors,
                                       &prod_scaling_factors));
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, only n_cell values need to be
    // stored.
1075     node->temporaries->data[kRecoveredCellWeights] =
1076         scratch_tensor_index + kRecoveredCellWeights;
1077     TfLiteTensor* recovered_cell_weights;
1078     TF_LITE_ENSURE_OK(context,
1079                       GetTemporarySafe(context, node, kRecoveredCellWeights,
1080                                        &recovered_cell_weights));
1081     recovered_cell_weights->type = kTfLiteFloat32;
1082     recovered_cell_weights->allocation_type = kTfLiteArenaRw;
1083     int recovered_cell_dims[1] = {n_cell};
1084     if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
1085                                    recovered_cell_dims)) {
1086       TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
1087       recovered_cell_weights_size->data[0] = n_cell;
1088       TF_LITE_ENSURE_OK(context,
1089                         context->ResizeTensor(context, recovered_cell_weights,
1090                                               recovered_cell_weights_size));
1091     }
1092 
1093     // Allocate a temporary tensor to store the accumulated int32 values.
1094     node->temporaries->data[kAccumScratch] =
1095         scratch_tensor_index + kAccumScratch;
1096     TfLiteTensor* accum_scratch;
1097     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
1098                                                 &accum_scratch));
1099     accum_scratch->type = kTfLiteInt32;
1100     accum_scratch->allocation_type = kTfLiteArenaRw;
1101     int accum_scratch_dims[2] = {n_cell, n_batch};
1102     if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
1103                                    accum_scratch_dims)) {
1104       TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
1105       accum_size->data[0] = n_cell;
1106       accum_size->data[1] = n_batch;
1107       TF_LITE_ENSURE_OK(
1108           context, context->ResizeTensor(context, accum_scratch, accum_size));
1109     }
1110     node->temporaries->data[kInputZeroPoints] =
1111         op_data->scratch_tensor_index + kInputZeroPoints;
1112     TfLiteTensor* input_zp;
1113     TF_LITE_ENSURE_OK(
1114         context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
1115     input_zp->type = kTfLiteFloat32;
1116     input_zp->allocation_type = kTfLiteArenaRw;
1117     if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
1118       TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
1119       input_zp_size->data[0] = n_batch;
1120       TF_LITE_ENSURE_OK(
1121           context, context->ResizeTensor(context, input_zp, input_zp_size));
1122     }
1123     node->temporaries->data[kOutputStateZeroPoints] =
1124         op_data->scratch_tensor_index + kOutputStateZeroPoints;
1125     TfLiteTensor* output_state_zp;
1126     TF_LITE_ENSURE_OK(context,
1127                       GetTemporarySafe(context, node, kOutputStateZeroPoints,
1128                                        &output_state_zp));
1129     output_state_zp->type = kTfLiteFloat32;
1130     output_state_zp->allocation_type = kTfLiteArenaRw;
1131     if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
1132       TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
1133       output_state_zp_size->data[0] = n_batch;
1134       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
1135                                                        output_state_zp_size));
1136     }
    node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRowSums, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
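    // One row of sums per weight matrix: four input-to-gate plus four
    // recurrent-to-gate matrices, minus the two input-gate matrices when
    // CIFG is used.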
    int row_sums_rows = use_cifg ? 6 : 8;
    const TfLiteTensor* projection_weights = GetOptionalInputTensor(
        context, node, lstm::full::kProjectionWeightsTensor);
    if (projection_weights != nullptr) {
      row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
    }
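    // The projection matrix has n_output rows, so its n_output row sums are
    // packed into ceil(n_output / n_cell) extra rows of width n_cell.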
    int row_sums_dims[2] = {row_sums_rows, n_cell};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
      row_sums_size->data[0] = row_sums_dims[0];
      row_sums_size->data[1] = row_sums_dims[1];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }

  if (is_integer) {
    // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16.
    // This code path needs 5 intermediate tensors per Op.
    // Populate quantization parameters.
    TF_LITE_ENSURE_OK(context,
                      PopulateQuantizedLstmParams8x8_16(
                          context, node, &op_data->integer_lstm_param));
    // Allocate the scratch buffers: four 16-bit buffers, one 8-bit buffer,
    // and one 32-bit buffer, each of size n_batch * n_cell.
    //
    // The CIFG case is handled identically, even though it could get by with
    // one buffer fewer.
    for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
      node->temporaries->data[scratch_index] =
          op_data->scratch_tensor_index + scratch_index;
      TfLiteTensor* scratch_tensor;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, scratch_index,
                                                  &scratch_tensor));

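      // Scratch layout expected by lstm_eval::EvalInteger8x8_16: indices 0-3
      // are the 16-bit buffers, index 4 is the 8-bit buffer, and index 5 is
      // the 32-bit buffer.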
      scratch_tensor->type = kTfLiteInt16;
      if (scratch_index == 4) {
        scratch_tensor->type = kTfLiteInt8;
      } else if (scratch_index == 5) {
        scratch_tensor->type = kTfLiteInt32;
      }

      scratch_tensor->allocation_type = kTfLiteArenaRw;
      const int scratch_dimension[2] = {n_batch, n_cell};
      if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
                                     scratch_dimension)) {
        TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
        scratch_buffer_size->data[0] = n_batch;
        scratch_buffer_size->data[1] = n_cell;
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, scratch_tensor,
                                                scratch_buffer_size));
      }
    }

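    // With asymmetric quantization, the zero-point-times-weight products in
    // each matmul do not depend on the input, so they can be folded into the
    // bias once at prepare time instead of being recomputed every invocation.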
    // Populate precomputed zp * weight.
    TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
                                   context, op_data, node));
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const bool use_layer_norm = op_data->use_layer_norm;
  const bool time_major = params->time_major;
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
                            &forget_gate_bias));
  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
                                 &cell_gate_bias));
  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
                            &output_gate_bias));

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);

  TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TFLITE_DCHECK(output_state != nullptr);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, lstm::full::kCellStateTensor);
  TFLITE_DCHECK(cell_state != nullptr);

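  // Layer-norm coefficients are only fetched when Prepare detected layer
  // normalization; otherwise they stay null.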
  const TfLiteTensor* input_layer_norm_coefficients =
      use_layer_norm
          ? GetOptionalInputTensor(
                context, node, lstm::full::kInputLayerNormCoefficientsTensor)
          : nullptr;
  const TfLiteTensor* forget_layer_norm_coefficients =
      use_layer_norm ? GetInput(context, node,
                                lstm::full::kForgetLayerNormCoefficientsTensor)
                     : nullptr;
  const TfLiteTensor* cell_layer_norm_coefficients =
      use_layer_norm ? GetInput(context, node,
                                lstm::full::kCellLayerNormCoefficientsTensor)
                     : nullptr;
  const TfLiteTensor* output_layer_norm_coefficients =
      use_layer_norm ? GetInput(context, node,
                                lstm::full::kOutputLayerNormCoefficientsTensor)
                     : nullptr;

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
                                           lstm::full::kOutputTensor, &output));

  // Copy out the LSTM-specific params so they can be passed to the eval
  // functions below.
  TfLiteLSTMParams lstm_params;
  lstm_params.activation = params->activation;
  lstm_params.cell_clip = params->cell_clip;
  lstm_params.proj_clip = params->proj_clip;
  lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;

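  // Dispatch on the weight type: float32 weights take the float path, while
  // (u)int8 weights take either the hybrid or the fully integer path.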
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      // Fetch the scratch buffer that Prepare registered in the node's
      // temporaries.
      TfLiteTensor* scratch_buffer;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
                                                  &scratch_buffer));
      return lstm_eval::EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_gate_bias, output_gate_bias,
          projection_weights, projection_bias, &lstm_params,
          /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer, output_state, cell_state,
          output);
    }
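    // Quantized weights: a float32 input selects hybrid evaluation (inputs
    // quantized on the fly); an integer input selects the fully integer
    // 8x8->16 kernel.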
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      const bool is_hybrid = input->type == kTfLiteFloat32;
      if (is_hybrid) {
        // Fetch the scratch buffer that Prepare registered in the node's
        // temporaries.
        TfLiteTensor* scratch_buffer;
        TF_LITE_ENSURE_OK(
            context,
            GetTemporarySafe(context, node, kScratchBuffer, &scratch_buffer));

        OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
        TfLiteTensor* row_sums;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, kRowSums, &row_sums));
        const int row_sums_size = row_sums->dims->data[0];
        return lstm_eval::EvalHybrid(
            input, input_to_input_weights,
            /*input_to_input_weights_ledger=*/nullptr, input_to_forget_weights,
            /*input_to_forget_weights_ledger=*/nullptr, input_to_cell_weights,
            /*input_to_cell_weights_ledger=*/nullptr, input_to_output_weights,
            /*input_to_output_weights_ledger=*/nullptr,
            recurrent_to_input_weights,
            /*recurrent_to_input_weights_ledger=*/nullptr,
            recurrent_to_forget_weights,
            /*recurrent_to_forget_weights_ledger=*/nullptr,
            recurrent_to_cell_weights,
            /*recurrent_to_cell_weights_ledger=*/nullptr,
            recurrent_to_output_weights,
            /*recurrent_to_output_weights_ledger=*/nullptr,
            cell_to_input_weights, cell_to_forget_weights,
            cell_to_output_weights, input_layer_norm_coefficients,
            forget_layer_norm_coefficients, cell_layer_norm_coefficients,
            output_layer_norm_coefficients,
            /*aux_input=*/nullptr,
            /*aux_input_to_input_weights=*/nullptr,
            /*aux_input_to_forget_weights=*/nullptr,
            /*aux_input_to_cell_weights=*/nullptr,
            /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
            forget_gate_bias, cell_gate_bias, output_gate_bias,
            projection_weights, /*projection_weights_ledger=*/nullptr,
            projection_bias, &lstm_params,
            /*forward_sequence=*/true, time_major,
            /*output_offset=*/0, scratch_buffer,
            GetTemporary(context, node, kInputScalingFactors),
            /*aux_input_sf=*/nullptr,
            GetTemporary(context, node, kOutputStateScalingFactors),
            GetTemporary(context, node, kProductScalingFactors),
            GetTemporary(context, node, kRecoveredCellWeights),
            GetTemporary(context, node, kInputQuantized),
            /*aux_input_quantized=*/nullptr,
            GetTemporary(context, node, kOutputStateQuantized),
            GetTemporary(context, node, kCellStateQuantized), output_state,
            cell_state, GetTemporary(context, node, kAccumScratch), output,
            GetTemporary(context, node, kInputZeroPoints),
            /*aux_input_zp=*/nullptr,
            GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
            row_sums_size, &op_data->compute_row_sums,
            CpuBackendContext::GetFromContext(context));
      } else {
        TfLiteTensor* scratch0;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 0, &scratch0));
        TfLiteTensor* scratch1;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 1, &scratch1));
        TfLiteTensor* scratch2;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 2, &scratch2));
        TfLiteTensor* scratch3;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 3, &scratch3));
        TfLiteTensor* scratch4;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 4, &scratch4));
        TfLiteTensor* scratch5;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 5, &scratch5));
        return lstm_eval::EvalInteger8x8_16(
            input, input_to_input_weights, input_to_forget_weights,
            input_to_cell_weights, input_to_output_weights,
            recurrent_to_input_weights, recurrent_to_forget_weights,
            recurrent_to_cell_weights, recurrent_to_output_weights,
            cell_to_input_weights, cell_to_forget_weights,
            cell_to_output_weights, input_layer_norm_coefficients,
            forget_layer_norm_coefficients, cell_layer_norm_coefficients,
            output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
            cell_gate_bias, output_gate_bias, projection_weights,
            projection_bias, &lstm_params, /*forward_sequence=*/true,
            time_major, &op_data->integer_lstm_param, output_state, cell_state,
            output, scratch0, scratch1, scratch2, scratch3, scratch4, scratch5,
            CpuBackendContext::GetFromContext(context));
      }
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
                         TfLiteTypeGetName(input_to_output_weights->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace unidirectional_sequence_lstm

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
                                 unidirectional_sequence_lstm::Free,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}
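
// Minimal usage sketch (an assumption about the surrounding build, not part
// of this file): the registration above is typically wired into an op
// resolver before the interpreter is constructed, e.g.
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(
//       tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
//       tflite::ops::builtin::Register_UNIDIRECTIONAL_SEQUENCE_LSTM());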

}  // namespace builtin
}  // namespace ops
}  // namespace tflite