1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <math.h>
17
18 #include <cstddef>
19
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/cpu_backend_context.h"
23 #include "tensorflow/lite/kernels/internal/compatibility.h"
24 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
25 #include "tensorflow/lite/kernels/internal/quantization_util.h"
26 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
28 #include "tensorflow/lite/kernels/kernel_util.h"
29 #include "tensorflow/lite/kernels/lstm_eval.h"
30 #include "tensorflow/lite/kernels/lstm_shared.h"
31
32 namespace tflite {
33 namespace ops {
34 namespace builtin {
35 namespace unidirectional_sequence_lstm {
36 namespace {
37
38 struct OpData {
// Whether the LSTM uses layer normalization.
40 bool use_layer_norm;
41 // The scratch tensor index.
42 int scratch_tensor_index;
43 bool compute_row_sums = false;
44
45 lstm_eval::IntegerLstmParameter integer_lstm_param;
46 };
47
TfLiteStatus PopulateQuantizedLstmParams8x8_16(
49 TfLiteContext* context, TfLiteNode* node,
50 lstm_eval::IntegerLstmParameter* integer_lstm_param) {
51 // Calculate quantized clip for projection and cell.
52 const auto* params =
53 static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data);
54 const float cell_clip = params->cell_clip;
55 const float proj_clip = params->proj_clip;
56
57 const TfLiteTensor* cell_state =
58 GetVariableInput(context, node, lstm::full::kCellStateTensor);
59 TF_LITE_ENSURE(context, cell_state != nullptr);
60 TfLiteTensor* output_tensor;
61 TF_LITE_ENSURE_OK(
62 context,
63 GetOutputSafe(context, node, lstm::full::kOutputTensor, &output_tensor));
64
65 auto* cell_state_params =
66 static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
67 auto* proj_params = static_cast<TfLiteAffineQuantization*>(
68 output_tensor->quantization.params);
69 if (cell_clip > 0.0) {
70 integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
71 std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
72 32767.0f));
73 } else {
74 integer_lstm_param->quantized_cell_clip = 0;
75 }
76 if (proj_clip > 0.0) {
77 integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
78 std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
79 } else {
80 integer_lstm_param->quantized_proj_clip = 0;
81 }
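// Illustrative example (not values from any particular model): with
// cell_clip = 8.0 and a cell-state scale of 2^-11, the quantized cell clip is
// 8.0 / 2^-11 = 16384; a ratio larger than 32767 would simply saturate to
// 32767, and a clip of 0 disables clipping entirely.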
82
83 // Calculate effective scales.
84 OpData* op_data = static_cast<OpData*>(node->user_data);
85 const bool use_layer_norm = op_data->use_layer_norm;
86
87 const TfLiteTensor* input;
88 TF_LITE_ENSURE_OK(
89 context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
90
91 const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
92 context, node, lstm::full::kInputToInputWeightsTensor);
93 const TfLiteTensor* input_to_forget_weights;
94 TF_LITE_ENSURE_OK(
95 context,
96 GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
97 &input_to_forget_weights));
98 const TfLiteTensor* input_to_cell_weights;
99 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
100 lstm::full::kInputToCellWeightsTensor,
101 &input_to_cell_weights));
102 const TfLiteTensor* input_to_output_weights;
103 TF_LITE_ENSURE_OK(
104 context,
105 GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
106 &input_to_output_weights));
107
108 const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
109 context, node, lstm::full::kRecurrentToInputWeightsTensor);
110 const TfLiteTensor* recurrent_to_forget_weights;
111 TF_LITE_ENSURE_OK(
112 context,
113 GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
114 &recurrent_to_forget_weights));
115 const TfLiteTensor* recurrent_to_cell_weights;
116 TF_LITE_ENSURE_OK(
117 context,
118 GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
119 &recurrent_to_cell_weights));
120 const TfLiteTensor* recurrent_to_output_weights;
121 TF_LITE_ENSURE_OK(
122 context,
123 GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
124 &recurrent_to_output_weights));
125
126 const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
127 context, node, lstm::full::kCellToInputWeightsTensor);
128 const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
129 context, node, lstm::full::kCellToForgetWeightsTensor);
130 const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
131 context, node, lstm::full::kCellToOutputWeightsTensor);
132
133 const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
134 context, node, lstm::full::kInputLayerNormCoefficientsTensor);
135 const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
136 context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
137 const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
138 context, node, lstm::full::kCellLayerNormCoefficientsTensor);
139 const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
140 context, node, lstm::full::kOutputLayerNormCoefficientsTensor);
141
142 const TfLiteTensor* projection_weights = GetOptionalInputTensor(
143 context, node, lstm::full::kProjectionWeightsTensor);
144
145 TfLiteTensor* output_state =
146 GetVariableInput(context, node, lstm::full::kOutputStateTensor);
147 TF_LITE_ENSURE(context, output_state != nullptr);
148
// Since we have already checked that these weights are either all present or
// all absent, checking just one of them is enough to determine the
// configuration.
151 const bool use_cifg = (input_to_input_weights == nullptr);
152 const bool use_peephole = (cell_to_output_weights != nullptr);
153 const bool use_projection = (projection_weights != nullptr);
154
155 // Get intermediate scales and zero points.
156 std::vector<float> intermediate_scale;
157 std::vector<int32> intermediate_zp;
158 for (int i = 0; i < 4; ++i) {
159 if (use_layer_norm) {
160 TfLiteTensor* intermediate;
161 TF_LITE_ENSURE_OK(context,
162 GetIntermediatesSafe(context, node, i, &intermediate));
163 auto* params = static_cast<TfLiteAffineQuantization*>(
164 intermediate->quantization.params);
165 intermediate_scale.push_back(params->scale->data[0]);
166 intermediate_zp.push_back(params->zero_point->data[0]);
167 } else {
168 // Q3.12 for activation functions.
169 intermediate_scale.push_back(std::pow(2, -12));
170 intermediate_zp.push_back(0);
171 }
172 }
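// Note: Q3.12 denotes a signed 16-bit fixed-point format with 3 integer bits
// and 12 fractional bits, i.e. a scale of 2^-12 and a representable range of
// roughly [-8, 8).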
// In the absence of projection, hidden becomes output and this intermediate
// is ignored.
175 TfLiteTensor* hidden;
176 TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
177 auto* hidden_params =
178 static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
179 intermediate_scale.push_back(hidden_params->scale->data[0]);
180 intermediate_zp.push_back(hidden_params->zero_point->data[0]);
181
182 // Scales.
183 const float default_scale = 1.0;
184 float input_scale = default_scale;
185 float input_to_input_weight_scale = default_scale;
186 float recurrent_to_input_weight_scale = default_scale;
187 float cell_to_input_weight_scale = default_scale;
188 float input_to_forget_weight_scale = default_scale;
189 float recurrent_to_forget_weight_scale = default_scale;
190 float cell_to_forget_weight_scale = default_scale;
191 float input_to_cell_weight_scale = default_scale;
192 float recurrent_to_cell_weight_scale = default_scale;
193 float input_to_output_weight_scale = default_scale;
194 float recurrent_to_output_weight_scale = default_scale;
195 float cell_to_output_weight_scale = default_scale;
196 float projection_weight_scale = default_scale;
197 float layer_norm_input_scale = default_scale;
198 float layer_norm_forget_scale = default_scale;
199 float layer_norm_cell_scale = default_scale;
200 float layer_norm_output_scale = default_scale;
201 float output_state_scale = default_scale;
202 int cell_scale = 1;
203
204 // Effective scales.
205 float effective_input_to_input_scale = default_scale;
206 float effective_recurrent_to_input_scale = default_scale;
207 float effective_cell_to_input_scale = default_scale;
208 float effective_input_to_forget_scale = default_scale;
209 float effective_recurrent_to_forget_scale = default_scale;
210 float effective_cell_to_forget_scale = default_scale;
211 float effective_input_to_cell_scale = default_scale;
212 float effective_recurrent_to_cell_scale = default_scale;
213 float effective_input_to_output_scale = default_scale;
214 float effective_recurrent_to_output_scale = default_scale;
215 float effective_cell_to_output_scale = default_scale;
216 float effective_proj_scale = default_scale;
217 float effective_hidden_scale = default_scale;
218
219 // Populate scales.
220 if (!use_cifg) {
221 input_to_input_weight_scale = input_to_input_weights->params.scale;
222 recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
223 }
224
225 if (use_peephole) {
226 if (!use_cifg) {
227 cell_to_input_weight_scale = cell_to_input_weights->params.scale;
228 }
229 cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
230 cell_to_output_weight_scale = cell_to_output_weights->params.scale;
231 }
232
233 if (use_layer_norm) {
234 if (!use_cifg) {
235 layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
236 }
237 layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
238 layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
239 layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
240 }
241
242 if (use_projection) {
243 projection_weight_scale = projection_weights->params.scale;
244 }
245 output_state_scale = output_state->params.scale;
246
247 input_to_forget_weight_scale = input_to_forget_weights->params.scale;
248 input_to_cell_weight_scale = input_to_cell_weights->params.scale;
249 input_to_output_weight_scale = input_to_output_weights->params.scale;
250 recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
251 recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
252 recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
253
254 // Check cell state (already used above)
255 TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
256 // TF_LITE_ENSURE(context, cell_scale <= -9);
257 integer_lstm_param->cell_scale = cell_scale;
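// cell_scale now holds the base-2 exponent of the cell-state scale; for
// example, a scale of 2^-11 yields cell_scale = -11. The integer kernel
// relies on the cell state using a power-of-two scale.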
258 input_scale = input->params.scale;
259
260 // Calculate effective scales.
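// Each effective scale below follows the pattern
//   weight_scale * activation_scale / intermediate_scale:
// the int32 accumulator of a weight/activation matmul carries the product of
// the two input scales, and dividing by the target intermediate scale gives
// the multiplier needed to requantize into that intermediate tensor.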
261 if (!use_cifg) {
262 effective_input_to_input_scale =
263 input_to_input_weight_scale * input_scale / intermediate_scale[0];
264 effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
265 output_state_scale /
266 intermediate_scale[0];
267 }
268 effective_input_to_forget_scale =
269 input_to_forget_weight_scale * input_scale / intermediate_scale[1];
270 effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
271 output_state_scale /
272 intermediate_scale[1];
273
274 effective_input_to_cell_scale =
275 input_to_cell_weight_scale * input_scale / intermediate_scale[2];
276 effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
277 output_state_scale /
278 intermediate_scale[2];
279
280 effective_input_to_output_scale =
281 input_to_output_weight_scale * input_scale / intermediate_scale[3];
282 effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
283 output_state_scale /
284 intermediate_scale[3];
285
286 effective_hidden_scale =
287 std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);
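// The gate output and tanh(cell state) are both represented in Q0.15 by the
// integer kernel, so their product carries an effective scale of
// 2^-15 * 2^-15; dividing by the hidden intermediate's scale requantizes that
// product into the int8 hidden tensor.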
288
289 effective_proj_scale =
290 projection_weight_scale * intermediate_scale[4] / output_state_scale;
291
292 if (use_peephole) {
293 if (!use_cifg) {
294 effective_cell_to_input_scale = std::pow(2, cell_scale) * // NOLINT
295 cell_to_input_weight_scale /
296 intermediate_scale[0];
297 }
298 effective_cell_to_forget_scale = std::pow(2, cell_scale) * // NOLINT
299 cell_to_forget_weight_scale /
300 intermediate_scale[1];
301 effective_cell_to_output_scale = std::pow(2, cell_scale) * // NOLINT
302 cell_to_output_weight_scale /
303 intermediate_scale[3];
304 }
305
306 // Decompose scales.
307 QuantizeMultiplier(effective_input_to_input_scale,
308 &integer_lstm_param->effective_input_to_input_scale_a,
309 &integer_lstm_param->effective_input_to_input_scale_b);
310 QuantizeMultiplier(effective_recurrent_to_input_scale,
311 &integer_lstm_param->effective_recurrent_to_input_scale_a,
312 &integer_lstm_param->effective_recurrent_to_input_scale_b);
313 QuantizeMultiplier(effective_cell_to_input_scale,
314 &integer_lstm_param->effective_cell_to_input_scale_a,
315 &integer_lstm_param->effective_cell_to_input_scale_b);
316 QuantizeMultiplier(effective_input_to_forget_scale,
317 &integer_lstm_param->effective_input_to_forget_scale_a,
318 &integer_lstm_param->effective_input_to_forget_scale_b);
319 QuantizeMultiplier(
320 effective_recurrent_to_forget_scale,
321 &integer_lstm_param->effective_recurrent_to_forget_scale_a,
322 &integer_lstm_param->effective_recurrent_to_forget_scale_b);
323 QuantizeMultiplier(effective_cell_to_forget_scale,
324 &integer_lstm_param->effective_cell_to_forget_scale_a,
325 &integer_lstm_param->effective_cell_to_forget_scale_b);
326 QuantizeMultiplier(effective_input_to_cell_scale,
327 &integer_lstm_param->effective_input_to_cell_scale_a,
328 &integer_lstm_param->effective_input_to_cell_scale_b);
329 QuantizeMultiplier(effective_recurrent_to_cell_scale,
330 &integer_lstm_param->effective_recurrent_to_cell_scale_a,
331 &integer_lstm_param->effective_recurrent_to_cell_scale_b);
332 QuantizeMultiplier(effective_input_to_output_scale,
333 &integer_lstm_param->effective_input_to_output_scale_a,
334 &integer_lstm_param->effective_input_to_output_scale_b);
335 QuantizeMultiplier(
336 effective_recurrent_to_output_scale,
337 &integer_lstm_param->effective_recurrent_to_output_scale_a,
338 &integer_lstm_param->effective_recurrent_to_output_scale_b);
339 QuantizeMultiplier(effective_cell_to_output_scale,
340 &integer_lstm_param->effective_cell_to_output_scale_a,
341 &integer_lstm_param->effective_cell_to_output_scale_b);
342 QuantizeMultiplier(effective_proj_scale,
343 &integer_lstm_param->effective_proj_scale_a,
344 &integer_lstm_param->effective_proj_scale_b);
345 QuantizeMultiplier(effective_hidden_scale,
346 &integer_lstm_param->effective_hidden_scale_a,
347 &integer_lstm_param->effective_hidden_scale_b);
348 QuantizeMultiplier(layer_norm_input_scale,
349 &integer_lstm_param->layer_norm_input_scale_a,
350 &integer_lstm_param->layer_norm_input_scale_b);
351 QuantizeMultiplier(layer_norm_forget_scale,
352 &integer_lstm_param->layer_norm_forget_scale_a,
353 &integer_lstm_param->layer_norm_forget_scale_b);
354 QuantizeMultiplier(layer_norm_cell_scale,
355 &integer_lstm_param->layer_norm_cell_scale_a,
356 &integer_lstm_param->layer_norm_cell_scale_b);
357 QuantizeMultiplier(layer_norm_output_scale,
358 &integer_lstm_param->layer_norm_output_scale_a,
359 &integer_lstm_param->layer_norm_output_scale_b);
360
361 integer_lstm_param->hidden_zp = intermediate_zp[4];
362
363 // 10000 is used to make sure the kernel logic does not overflow.
364 if (!use_cifg) {
365 integer_lstm_param->input_variance_guard =
366 std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale));
367 }
368 integer_lstm_param->forget_variance_guard =
369 std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale));
370 integer_lstm_param->cell_variance_guard =
371 std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale));
372 integer_lstm_param->output_variance_guard =
373 std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale));
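// The std::max(1, ...) clamp guarantees each guard is at least 1 even for
// very small layer-norm scales, so the variance term used by the integer
// layer-norm computation never degenerates to zero.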
374
375 return kTfLiteOk;
376 }
377
378 } // namespace
379
380 // Temporary tensors
381 enum TemporaryTensor {
382 kScratchBuffer = 0,
383 kInputQuantized = 1,
384 kOutputStateQuantized = 2,
385 kCellStateQuantized = 3,
386 kInputScalingFactors = 4,
387 kOutputStateScalingFactors = 5,
388 kProductScalingFactors = 6,
389 kRecoveredCellWeights = 7,
390 kAccumScratch = 8,
391 kInputZeroPoints = 9,
392 kOutputStateZeroPoints = 10,
393 kRowSums = 11,
394 kNumTemporaryTensors = 12,
395 };
396
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
398 auto* op_data = new OpData();
399 context->AddTensors(context, kNumTemporaryTensors,
400 &op_data->scratch_tensor_index);
401 return op_data;
402 }
403
void Free(TfLiteContext* context, void* buffer) {
405 delete reinterpret_cast<OpData*>(buffer);
406 }
407
// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
410 TfLiteNode* node, int n_input,
411 int n_output, int n_cell,
412 bool use_layer_norm, bool is_integer) {
413 const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
414
415 // Making sure clipping parameters have valid values.
416 // == 0 means no clipping
417 // > 0 means clipping
418 TF_LITE_ENSURE(context, params->cell_clip >= 0);
419 TF_LITE_ENSURE(context, params->proj_clip >= 0);
420
421 const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
422 context, node, lstm::full::kInputToInputWeightsTensor);
423 if (input_to_input_weights != nullptr) {
424 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
425 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
426 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
427 }
428
429 const TfLiteTensor* input_to_forget_weights;
430 TF_LITE_ENSURE_OK(
431 context,
432 GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
433 &input_to_forget_weights));
434 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
435 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
436 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
437
438 const TfLiteTensor* input_to_cell_weights;
439 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
440 lstm::full::kInputToCellWeightsTensor,
441 &input_to_cell_weights));
442 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
443 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
444 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
445
446 const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
447 context, node, lstm::full::kRecurrentToInputWeightsTensor);
448 if (recurrent_to_input_weights != nullptr) {
449 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
450 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
451 n_cell);
452 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
453 n_output);
454 }
455
456 const TfLiteTensor* recurrent_to_forget_weights;
457 TF_LITE_ENSURE_OK(
458 context,
459 GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
460 &recurrent_to_forget_weights));
461 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
462 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
463 n_cell);
464 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
465 n_output);
466
467 const TfLiteTensor* recurrent_to_cell_weights;
468 TF_LITE_ENSURE_OK(
469 context,
470 GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
471 &recurrent_to_cell_weights));
472 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
473 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
474 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
475 n_output);
476
// We make sure the input gate's parameters are either both present (regular
// LSTM) or both absent (CIFG-LSTM).
479 const bool cifg_weights_all_or_none =
480 ((input_to_input_weights != nullptr) &&
481 (recurrent_to_input_weights != nullptr)) ||
482 ((input_to_input_weights == nullptr) &&
483 (recurrent_to_input_weights == nullptr));
484 TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
485
486 const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
487 context, node, lstm::full::kCellToInputWeightsTensor);
488 if (cell_to_input_weights != nullptr) {
489 TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
490 TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
491 TF_LITE_ENSURE_TYPES_EQ(
492 context, cell_to_input_weights->type,
493 is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
494 }
495
496 const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
497 context, node, lstm::full::kCellToForgetWeightsTensor);
498 if (cell_to_forget_weights != nullptr) {
499 TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
500 TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
501 TF_LITE_ENSURE_TYPES_EQ(
502 context, cell_to_forget_weights->type,
503 is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
504 }
505
506 const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
507 context, node, lstm::full::kCellToOutputWeightsTensor);
508 if (cell_to_output_weights != nullptr) {
509 TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
510 TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
511 TF_LITE_ENSURE_TYPES_EQ(
512 context, cell_to_output_weights->type,
513 is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
514 }
515
// Making sure the peephole weights are either all present or all absent.
517 const bool use_cifg = (input_to_input_weights == nullptr);
518 const bool peephole_weights_all_or_none =
519 ((cell_to_input_weights != nullptr || use_cifg) &&
520 (cell_to_forget_weights != nullptr) &&
521 (cell_to_output_weights != nullptr)) ||
522 ((cell_to_input_weights == nullptr) &&
523 (cell_to_forget_weights == nullptr) &&
524 (cell_to_output_weights == nullptr));
525 TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
526
527 // Make sure the input gate bias is present only when not a CIFG-LSTM.
528 const TfLiteTensor* input_gate_bias =
529 GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
530 if (use_cifg) {
531 TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
532 } else {
533 TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
534 TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
535 if (is_integer) {
536 TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
537 } else {
538 TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
539 }
540 }
541
542 const TfLiteTensor* forget_gate_bias;
543 TF_LITE_ENSURE_OK(
544 context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
545 &forget_gate_bias));
546 TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
547 TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
548 if (is_integer) {
549 TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
550 } else {
551 TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
552 }
553
554 const TfLiteTensor* cell_gate_bias;
555 TF_LITE_ENSURE_OK(context,
556 GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
557 &cell_gate_bias));
558 TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
559 TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
560 if (is_integer) {
561 TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
562 } else {
563 TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
564 }
565
566 const TfLiteTensor* output_gate_bias;
567 TF_LITE_ENSURE_OK(
568 context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
569 &output_gate_bias));
570 TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
571 TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
572 if (is_integer) {
573 TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
574 } else {
575 TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
576 }
577
578 const TfLiteTensor* projection_weights = GetOptionalInputTensor(
579 context, node, lstm::full::kProjectionWeightsTensor);
580 if (projection_weights != nullptr) {
581 TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
582 TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
583 TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
584 }
585
586 const TfLiteTensor* projection_bias =
587 GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
588 if (projection_bias != nullptr) {
589 TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
590 TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
591 if (is_integer) {
592 TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
593 } else {
594 TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
595 }
596 }
597
598 // Making sure the projection tensors are consistent:
599 // 1) If projection weight is not present, then projection bias should not be
600 // present.
601 // 2) If projection weight is present, then projection bias is optional.
602 // TODO(ghodrat): make sure this is correct.
const bool projection_tensors_consistent =
((projection_weights != nullptr) || (projection_bias == nullptr));
TF_LITE_ENSURE(context, projection_tensors_consistent == true);
606
607 if (use_layer_norm) {
608 const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
609 context, node, lstm::full::kInputLayerNormCoefficientsTensor);
610 if (use_cifg) {
611 TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
612 } else {
613 TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
614 TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
615 TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
616 n_cell);
617 if (is_integer) {
618 TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
619 kTfLiteInt16);
620 } else {
621 TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
622 kTfLiteFloat32);
623 }
624 }
625
626 const TfLiteTensor* forget_layer_norm_coefficients;
627 TF_LITE_ENSURE_OK(
628 context, GetInputSafe(context, node,
629 lstm::full::kForgetLayerNormCoefficientsTensor,
630 &forget_layer_norm_coefficients));
631 TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
632 TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
633 n_cell);
634 if (is_integer) {
635 TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
636 kTfLiteInt16);
637 } else {
638 TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
639 kTfLiteFloat32);
640 }
641
642 const TfLiteTensor* cell_layer_norm_coefficients;
643 TF_LITE_ENSURE_OK(context,
644 GetInputSafe(context, node,
645 lstm::full::kCellLayerNormCoefficientsTensor,
646 &cell_layer_norm_coefficients));
647 TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
648 TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
649 n_cell);
650 if (is_integer) {
651 TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
652 kTfLiteInt16);
653 } else {
654 TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
655 kTfLiteFloat32);
656 }
657
658 const TfLiteTensor* output_layer_norm_coefficients;
659 TF_LITE_ENSURE_OK(
660 context, GetInputSafe(context, node,
661 lstm::full::kOutputLayerNormCoefficientsTensor,
662 &output_layer_norm_coefficients));
663 TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
664 TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
665 n_cell);
666 if (is_integer) {
667 TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
668 kTfLiteInt16);
669 } else {
670 TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
671 kTfLiteFloat32);
672 }
673 }
674
675 return kTfLiteOk;
676 }
677
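// Precomputes, for each output row, bias[row] + zero_point * row_sum(weight).
// The callers below pass the negated activation zero point, so that adding
// the raw int8 matmul x . w to this precomputed term yields
// (x - zp) . w + bias without recomputing the weight row sums per invocation.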
TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
679 TfLiteContext* context, int32_t zero_point,
680 const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
681 std::unique_ptr<int32_t[]>* output) {
682 if (weight_tensor == nullptr) {
683 return kTfLiteOk;
684 }
685
686 const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
687 TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
688 const int row = weight_shape.Dims(0);
689 const int col = weight_shape.Dims(1);
690 output->reset(new int32_t[row]);
691 if (bias_tensor == nullptr) {
692 memset(output->get(), 0, row * sizeof(int32_t));
693 } else {
694 const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
695 memcpy(output->get(), bias, row * sizeof(int32_t));
696 }
697 if (zero_point != 0) {
698 const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
699 tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col,
700 output->get());
701 }
702 return kTfLiteOk;
703 }
704
TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
706 OpData* op_data,
707 TfLiteNode* node) {
708 const TfLiteTensor* input;
709 TF_LITE_ENSURE_OK(
710 context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
711 const TfLiteTensor* output_state =
712 GetVariableInput(context, node, lstm::full::kOutputStateTensor);
713 TF_LITE_ENSURE(context, output_state != nullptr);
714
715 const int32_t input_zero_point = -input->params.zero_point;
716 const int32_t output_state_zero_point = -output_state->params.zero_point;
717
718 const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
719 context, node, lstm::full::kInputToInputWeightsTensor);
720 const TfLiteTensor* input_to_forget_weights;
721 TF_LITE_ENSURE_OK(
722 context,
723 GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
724 &input_to_forget_weights));
725 const TfLiteTensor* input_to_cell_weights;
726 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
727 lstm::full::kInputToCellWeightsTensor,
728 &input_to_cell_weights));
729 const TfLiteTensor* input_to_output_weights;
730 TF_LITE_ENSURE_OK(
731 context,
732 GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
733 &input_to_output_weights));
734
735 const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
736 context, node, lstm::full::kRecurrentToInputWeightsTensor);
737 const TfLiteTensor* recurrent_to_forget_weights;
738 TF_LITE_ENSURE_OK(
739 context,
740 GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
741 &recurrent_to_forget_weights));
742 const TfLiteTensor* recurrent_to_cell_weights;
743 TF_LITE_ENSURE_OK(
744 context,
745 GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
746 &recurrent_to_cell_weights));
747 const TfLiteTensor* recurrent_to_output_weights;
748 TF_LITE_ENSURE_OK(
749 context,
750 GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
751 &recurrent_to_output_weights));
752
753 const TfLiteTensor* projection_weights = GetOptionalInputTensor(
754 context, node, lstm::full::kProjectionWeightsTensor);
755 const TfLiteTensor* projection_bias =
756 GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
757
758 lstm_eval::IntegerLstmParameter* integer_lstm_params =
759 &op_data->integer_lstm_param;
760
761 const TfLiteTensor* intermediate =
762 &context->tensors[node->intermediates->data[4]];
763 const auto* params =
764 static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
765 const int32_t hidden_zp = params->zero_point->data[0];
766
767 // Get bias and perform zero point calculation.
768 // When there is layer normalization, the gate bias does not apply to matmul
769 // directly:
770 // y = ln(w * x + w * r + w * c) + b.
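// Without layer normalization, the gate bias folds directly into the
// precomputed term: y = w * x + w * r + w * c + b.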
771 const bool is_layer_norm = op_data->use_layer_norm;
772
773 // Forget gate.
774 const TfLiteTensor* forget_gate_bias =
775 is_layer_norm
776 ? nullptr
777 : GetInput(context, node, lstm::full::kForgetGateBiasTensor);
778 TF_LITE_ENSURE_OK(
779 context,
780 PrecomputeZeroPointTimesWeightWithBias(
781 context, input_zero_point, input_to_forget_weights, forget_gate_bias,
782 &(integer_lstm_params->input_to_forget_effective_bias)));
783
784 TF_LITE_ENSURE_OK(
785 context,
786 PrecomputeZeroPointTimesWeightWithBias(
787 context, output_state_zero_point, recurrent_to_forget_weights,
788 nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));
789
790 // Modulation gate.
791 const TfLiteTensor* cell_gate_bias =
792 is_layer_norm ? nullptr
793 : GetInput(context, node, lstm::full::kCellGateBiasTensor);
794 TF_LITE_ENSURE_OK(
795 context,
796 PrecomputeZeroPointTimesWeightWithBias(
797 context, input_zero_point, input_to_cell_weights, cell_gate_bias,
798 &(integer_lstm_params->input_to_cell_effective_bias)));
799 TF_LITE_ENSURE_OK(
800 context,
801 PrecomputeZeroPointTimesWeightWithBias(
802 context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
803 &(integer_lstm_params->recurrent_to_cell_effective_bias)));
804
805 // Output gate.
806 const TfLiteTensor* output_gate_bias =
807 is_layer_norm
808 ? nullptr
809 : GetInput(context, node, lstm::full::kOutputGateBiasTensor);
810 TF_LITE_ENSURE_OK(
811 context,
812 PrecomputeZeroPointTimesWeightWithBias(
813 context, input_zero_point, input_to_output_weights, output_gate_bias,
814 &(integer_lstm_params->input_to_output_effective_bias)));
815
816 TF_LITE_ENSURE_OK(
817 context,
818 PrecomputeZeroPointTimesWeightWithBias(
819 context, output_state_zero_point, recurrent_to_output_weights,
820 nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));
821
// Input gate. The calculation is only meaningful for the non-CIFG case.
823 const TfLiteTensor* input_gate_bias =
824 is_layer_norm ? nullptr
825 : GetInput(context, node, lstm::full::kInputGateBiasTensor);
826 TF_LITE_ENSURE_OK(
827 context,
828 PrecomputeZeroPointTimesWeightWithBias(
829 context, input_zero_point, input_to_input_weights, input_gate_bias,
830 &(integer_lstm_params->input_to_input_effective_bias)));
831 TF_LITE_ENSURE_OK(
832 context,
833 PrecomputeZeroPointTimesWeightWithBias(
834 context, output_state_zero_point, recurrent_to_input_weights, nullptr,
835 &(integer_lstm_params->recurrent_to_input_effective_bias)));
836
// Projection bias. The calculation is only meaningful when projection is used.
838 TF_LITE_ENSURE_OK(context,
839 PrecomputeZeroPointTimesWeightWithBias(
840 context, hidden_zp, projection_weights, projection_bias,
841 &(integer_lstm_params->projection_effective_bias)));
842 return kTfLiteOk;
843 }
844
845 // Resize the output and state tensors based on the sizes of the input tensors.
846 // Allocate a temporary scratch tensor. Also check that the sizes of the input
847 // tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
849 OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
850 const int scratch_tensor_index = op_data->scratch_tensor_index;
851
852 // Check we have all the inputs and outputs we need.
853 bool use_layer_norm = false;
854 if (node->inputs->size == 24) {
855 const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
856 context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
857 if (forget_layer_norm_coefficients == nullptr) {
858 use_layer_norm = false;
859 } else {
860 use_layer_norm = true;
861 }
862 } else if (node->inputs->size == 20) {
863 // This is deprecated and is only kept here for backward compatibility.
864 use_layer_norm = false;
865 } else {
866 context->ReportError(
867 context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
868 node->inputs->size);
869 return kTfLiteError;
870 }
871 TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
872 op_data->use_layer_norm = use_layer_norm;
873
// Infer the batch size, number of outputs, sequence length and number of
// cells from the input tensors.
876 const TfLiteTensor* input;
877 TF_LITE_ENSURE_OK(
878 context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
879 const bool is_integer = input->type == kTfLiteInt8;
880 TF_LITE_ENSURE(context, input->dims->size > 1);
881 const auto* params =
882 reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
883 node->builtin_data);
884 const bool time_major = params->time_major;
885 const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
886 const int n_input = input->dims->data[2];
887
888 const TfLiteTensor* input_to_output_weights;
889 TF_LITE_ENSURE_OK(
890 context,
891 GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
892 &input_to_output_weights));
893 const int n_cell = input_to_output_weights->dims->data[0];
894 TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
895 TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
896
897 const TfLiteTensor* recurrent_to_output_weights;
898 TF_LITE_ENSURE_OK(
899 context,
900 GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
901 &recurrent_to_output_weights));
902 TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
903 TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
904 n_cell);
905 const int n_output = recurrent_to_output_weights->dims->data[1];
906
// Check that the input tensor dimensions match each other.
908 TF_LITE_ENSURE_OK(
909 context, CheckInputTensorDimensions(context, node, n_input, n_output,
910 n_cell, use_layer_norm, is_integer));
911
912 // Get the pointer to output, output_state and cell_state buffer tensors.
913 TfLiteTensor* output;
914 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
915 lstm::full::kOutputTensor, &output));
916
917 TfLiteTensor* output_state =
918 GetVariableInput(context, node, lstm::full::kOutputStateTensor);
919 TF_LITE_ENSURE(context, output_state != nullptr);
920 TfLiteTensor* cell_state =
921 GetVariableInput(context, node, lstm::full::kCellStateTensor);
922 TF_LITE_ENSURE(context, cell_state != nullptr);
923
// Check the shape of the input state tensors.
// These tensors may be 1D or 2D. It's fine as long as the total size is
// correct.
927 TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
928 TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
929
930 // Resize the output tensors.
931 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
932 output_size->data[input->dims->size - 1] = n_output;
933 TF_LITE_ENSURE_OK(context,
934 context->ResizeTensor(context, output, output_size));
935
936 if (is_integer) {
937 const int num_intermediate_tensors = node->intermediates->size;
938 TF_LITE_ENSURE(context, num_intermediate_tensors == 5);
939 }
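// The five intermediates are the four gate pre-activations plus the hidden
// (projection input) tensor; their quantization parameters are read by
// PopulateQuantizedLstmParams8x8_16.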
940
941 TfLiteIntArrayFree(node->temporaries);
942 if (IsHybridOp(input, input_to_output_weights)) {
943 node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
944 } else if (is_integer) {
945 node->temporaries = TfLiteIntArrayCreate(6);
946 } else {
947 node->temporaries = TfLiteIntArrayCreate(1);
948 }
949 node->temporaries->data[kScratchBuffer] =
950 scratch_tensor_index + kScratchBuffer;
951
952 // Create a scratch buffer tensor.
953 TfLiteTensor* scratch_buffer;
954 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
955 &scratch_buffer));
956 scratch_buffer->type = input->type;
957 scratch_buffer->allocation_type = kTfLiteArenaRw;
958
959 const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
960 context, node, lstm::full::kInputToInputWeightsTensor);
961 const bool use_cifg = (input_to_input_weights == nullptr);
962 TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
963 scratch_buffer_size->data[0] = n_batch;
964 if (use_cifg) {
965 // Reserving space for Cell, Forget, Output gates
966 scratch_buffer_size->data[1] = n_cell * 3;
967 } else {
968 // Reserving space for Input, Cell, Forget, Output gates
969 scratch_buffer_size->data[1] = n_cell * 4;
970 }
971 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
972 scratch_buffer_size));
973
974 if (IsHybridOp(input, input_to_output_weights)) {
975 op_data->compute_row_sums = true;
976 // Allocate temporary tensors to store quantized values of input,
977 // output_state and cell_state tensors.
978 node->temporaries->data[kInputQuantized] =
979 scratch_tensor_index + kInputQuantized;
980 TfLiteTensor* input_quantized;
981 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
982 &input_quantized));
983 input_quantized->type = input_to_output_weights->type;
984 input_quantized->allocation_type = kTfLiteArenaRw;
985 if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
986 TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
987 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
988 input_quantized_size));
989 }
990 node->temporaries->data[kOutputStateQuantized] =
991 scratch_tensor_index + kOutputStateQuantized;
992 TfLiteTensor* output_state_quantized;
993 TF_LITE_ENSURE_OK(context,
994 GetTemporarySafe(context, node, kOutputStateQuantized,
995 &output_state_quantized));
996 output_state_quantized->type = input_to_output_weights->type;
997 output_state_quantized->allocation_type = kTfLiteArenaRw;
998 if (!TfLiteIntArrayEqual(output_state_quantized->dims,
999 output_state->dims)) {
1000 TfLiteIntArray* output_state_quantized_size =
1001 TfLiteIntArrayCopy(output_state->dims);
1002 TF_LITE_ENSURE_OK(context,
1003 context->ResizeTensor(context, output_state_quantized,
1004 output_state_quantized_size));
1005 }
1006 node->temporaries->data[kCellStateQuantized] =
1007 scratch_tensor_index + kCellStateQuantized;
1008 TfLiteTensor* cell_state_quantized;
1009 TF_LITE_ENSURE_OK(context,
1010 GetTemporarySafe(context, node, kCellStateQuantized,
1011 &cell_state_quantized));
1012 cell_state_quantized->type = input_to_output_weights->type;
1013 cell_state_quantized->allocation_type = kTfLiteArenaRw;
1014 if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
1015 TfLiteIntArray* cell_state_quantized_size =
1016 TfLiteIntArrayCopy(cell_state->dims);
1017 TF_LITE_ENSURE_OK(context,
1018 context->ResizeTensor(context, cell_state_quantized,
1019 cell_state_quantized_size));
1020 }
1021
// Allocate temporary tensors to store scaling factors and product scaling
// factors. The latter is convenience storage that lets a vector be quantized
// once (producing the scaling factors) and then multiplied with different
// matrices (which requires multiplying those scaling factors by the scaling
// factor of each matrix).
1027 node->temporaries->data[kInputScalingFactors] =
1028 op_data->scratch_tensor_index + kInputScalingFactors;
1029 TfLiteTensor* input_sf;
1030 TF_LITE_ENSURE_OK(
1031 context,
1032 GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
1033 input_sf->type = kTfLiteFloat32;
1034 input_sf->allocation_type = kTfLiteArenaRw;
1035 int scaling_dims[1] = {n_batch};
1036 if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
1037 TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
1038 input_sf_size->data[0] = n_batch;
1039 TF_LITE_ENSURE_OK(
1040 context, context->ResizeTensor(context, input_sf, input_sf_size));
1041 }
1042 node->temporaries->data[kOutputStateScalingFactors] =
1043 op_data->scratch_tensor_index + kOutputStateScalingFactors;
1044 TfLiteTensor* output_state_sf;
1045 TF_LITE_ENSURE_OK(
1046 context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
1047 &output_state_sf));
1048 output_state_sf->type = kTfLiteFloat32;
1049 output_state_sf->allocation_type = kTfLiteArenaRw;
1050 if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
1051 TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
1052 output_state_sf_size->data[0] = n_batch;
1053 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
1054 output_state_sf_size));
1055 }
1056 node->temporaries->data[kProductScalingFactors] =
1057 scratch_tensor_index + kProductScalingFactors;
1058 TfLiteTensor* prod_scaling_factors;
1059 TF_LITE_ENSURE_OK(context,
1060 GetTemporarySafe(context, node, kProductScalingFactors,
1061 &prod_scaling_factors));
1062 prod_scaling_factors->type = kTfLiteFloat32;
1063 prod_scaling_factors->allocation_type = kTfLiteArenaRw;
1064 if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
1065 scaling_dims)) {
1066 TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
1067 prod_scaling_factors_size->data[0] = n_batch;
1068 TF_LITE_ENSURE_OK(context,
1069 context->ResizeTensor(context, prod_scaling_factors,
1070 prod_scaling_factors_size));
1071 }
1072
// Allocate a temporary tensor to store the recovered cell weights. Since
// these are used for diagonal matrices, only n_cell values need to be stored.
1075 node->temporaries->data[kRecoveredCellWeights] =
1076 scratch_tensor_index + kRecoveredCellWeights;
1077 TfLiteTensor* recovered_cell_weights;
1078 TF_LITE_ENSURE_OK(context,
1079 GetTemporarySafe(context, node, kRecoveredCellWeights,
1080 &recovered_cell_weights));
1081 recovered_cell_weights->type = kTfLiteFloat32;
1082 recovered_cell_weights->allocation_type = kTfLiteArenaRw;
1083 int recovered_cell_dims[1] = {n_cell};
1084 if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
1085 recovered_cell_dims)) {
1086 TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
1087 recovered_cell_weights_size->data[0] = n_cell;
1088 TF_LITE_ENSURE_OK(context,
1089 context->ResizeTensor(context, recovered_cell_weights,
1090 recovered_cell_weights_size));
1091 }
1092
1093 // Allocate a temporary tensor to store the accumulated int32 values.
1094 node->temporaries->data[kAccumScratch] =
1095 scratch_tensor_index + kAccumScratch;
1096 TfLiteTensor* accum_scratch;
1097 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
1098 &accum_scratch));
1099 accum_scratch->type = kTfLiteInt32;
1100 accum_scratch->allocation_type = kTfLiteArenaRw;
1101 int accum_scratch_dims[2] = {n_cell, n_batch};
1102 if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
1103 accum_scratch_dims)) {
1104 TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
1105 accum_size->data[0] = n_cell;
1106 accum_size->data[1] = n_batch;
1107 TF_LITE_ENSURE_OK(
1108 context, context->ResizeTensor(context, accum_scratch, accum_size));
1109 }
1110 node->temporaries->data[kInputZeroPoints] =
1111 op_data->scratch_tensor_index + kInputZeroPoints;
1112 TfLiteTensor* input_zp;
1113 TF_LITE_ENSURE_OK(
1114 context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
1115 input_zp->type = kTfLiteFloat32;
1116 input_zp->allocation_type = kTfLiteArenaRw;
1117 if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
1118 TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
1119 input_zp_size->data[0] = n_batch;
1120 TF_LITE_ENSURE_OK(
1121 context, context->ResizeTensor(context, input_zp, input_zp_size));
1122 }
1123 node->temporaries->data[kOutputStateZeroPoints] =
1124 op_data->scratch_tensor_index + kOutputStateZeroPoints;
1125 TfLiteTensor* output_state_zp;
1126 TF_LITE_ENSURE_OK(context,
1127 GetTemporarySafe(context, node, kOutputStateZeroPoints,
1128 &output_state_zp));
1129 output_state_zp->type = kTfLiteFloat32;
1130 output_state_zp->allocation_type = kTfLiteArenaRw;
1131 if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
1132 TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
1133 output_state_zp_size->data[0] = n_batch;
1134 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
1135 output_state_zp_size));
1136 }
1137 node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
1138 TfLiteTensor* row_sums;
1139 TF_LITE_ENSURE_OK(context,
1140 GetTemporarySafe(context, node, kRowSums, &row_sums));
1141 row_sums->type = kTfLiteInt32;
1142 row_sums->allocation_type = kTfLiteArenaRwPersistent;
1143 int row_sums_rows = use_cifg ? 6 : 8;
1144 const TfLiteTensor* projection_weights = GetOptionalInputTensor(
1145 context, node, lstm::full::kProjectionWeightsTensor);
1146 if (projection_weights != nullptr) {
1147 row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
1148 }
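// Eight row-sum rows cover the four input_to_* and the four recurrent_to_*
// matrices; CIFG drops the two input-gate matrices. The projection matrix has
// n_output row sums, which fit into the extra ceil(n_output / n_cell) rows of
// this {row_sums_rows, n_cell} tensor.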
1149 int row_sums_dims[2] = {row_sums_rows, n_cell};
1150 if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
1151 TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
1152 row_sums_size->data[0] = row_sums_dims[0];
1153 row_sums_size->data[1] = row_sums_dims[1];
1154 TF_LITE_ENSURE_OK(
1155 context, context->ResizeTensor(context, row_sums, row_sums_size));
1156 }
1157 }
1158
1159 if (is_integer) {
1160 // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16.
1161 // This code path needs 5 intermediate tensors per Op.
1162 // Populate quantization parameters.
1163 PopulateQuantizedLstmParams8x8_16(context, node,
1164 &op_data->integer_lstm_param);
// Allocate the scratch buffers: 16-bit buffers of size n_batch * n_cell for
// the gates, plus one 8-bit buffer and one 32-bit buffer of the same size
// (see the type assignments in the loop below).
//
// Handle the CIFG case as well, which might save one buffer.
1170 for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
1171 node->temporaries->data[scratch_index] =
1172 op_data->scratch_tensor_index + scratch_index;
1173 TfLiteTensor* scratch_tensor;
1174 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, scratch_index,
1175 &scratch_tensor));
1176
1177 scratch_tensor->type = kTfLiteInt16;
1178 if (scratch_index == 4) {
1179 scratch_tensor->type = kTfLiteInt8;
1180 } else if (scratch_index == 5) {
1181 scratch_tensor->type = kTfLiteInt32;
1182 }
1183
1184 scratch_tensor->allocation_type = kTfLiteArenaRw;
1185 const int scratch_dimension[2] = {n_batch, n_cell};
1186 if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
1187 scratch_dimension)) {
1188 TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
1189 scratch_buffer_size->data[0] = n_batch;
1190 scratch_buffer_size->data[1] = n_cell;
1191 TF_LITE_ENSURE_OK(context,
1192 context->ResizeTensor(context, scratch_tensor,
1193 scratch_buffer_size));
1194 }
1195 }
1196
1197 // Populate precomputed zp * weight.
1198 TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
1199 context, op_data, node));
1200 }
1201
1202 return kTfLiteOk;
1203 }
1204
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
1206 const auto* params =
1207 reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
1208 node->builtin_data);
1209 const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
1210 const bool use_layer_norm = op_data->use_layer_norm;
1211 const bool time_major = params->time_major;
1212 const TfLiteTensor* input;
1213 TF_LITE_ENSURE_OK(
1214 context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
1215
1216 const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
1217 context, node, lstm::full::kInputToInputWeightsTensor);
1218 const TfLiteTensor* input_to_forget_weights;
1219 TF_LITE_ENSURE_OK(
1220 context,
1221 GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
1222 &input_to_forget_weights));
1223 const TfLiteTensor* input_to_cell_weights;
1224 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
1225 lstm::full::kInputToCellWeightsTensor,
1226 &input_to_cell_weights));
1227 const TfLiteTensor* input_to_output_weights;
1228 TF_LITE_ENSURE_OK(
1229 context,
1230 GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
1231 &input_to_output_weights));
1232
1233 const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
1234 context, node, lstm::full::kRecurrentToInputWeightsTensor);
1235 const TfLiteTensor* recurrent_to_forget_weights;
1236 TF_LITE_ENSURE_OK(
1237 context,
1238 GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
1239 &recurrent_to_forget_weights));
1240 const TfLiteTensor* recurrent_to_cell_weights;
1241 TF_LITE_ENSURE_OK(
1242 context,
1243 GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
1244 &recurrent_to_cell_weights));
1245 const TfLiteTensor* recurrent_to_output_weights;
1246 TF_LITE_ENSURE_OK(
1247 context,
1248 GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
1249 &recurrent_to_output_weights));
1250
1251 const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
1252 context, node, lstm::full::kCellToInputWeightsTensor);
1253 const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
1254 context, node, lstm::full::kCellToForgetWeightsTensor);
1255 const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
1256 context, node, lstm::full::kCellToOutputWeightsTensor);
1257
1258 const TfLiteTensor* input_gate_bias =
1259 GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
1260 const TfLiteTensor* forget_gate_bias;
1261 TF_LITE_ENSURE_OK(
1262 context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
1263 &forget_gate_bias));
1264 const TfLiteTensor* cell_gate_bias;
1265 TF_LITE_ENSURE_OK(context,
1266 GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
1267 &cell_gate_bias));
1268 const TfLiteTensor* output_gate_bias;
1269 TF_LITE_ENSURE_OK(
1270 context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
1271 &output_gate_bias));
1272
1273 const TfLiteTensor* projection_weights = GetOptionalInputTensor(
1274 context, node, lstm::full::kProjectionWeightsTensor);
1275 const TfLiteTensor* projection_bias =
1276 GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
1277
1278 TfLiteTensor* output_state =
1279 GetVariableInput(context, node, lstm::full::kOutputStateTensor);
1280 TFLITE_DCHECK(output_state != nullptr);
1281 TfLiteTensor* cell_state =
1282 GetVariableInput(context, node, lstm::full::kCellStateTensor);
1283 TFLITE_DCHECK(cell_state != nullptr);
1284
1285 const TfLiteTensor* input_layer_norm_coefficients =
1286 use_layer_norm
1287 ? GetOptionalInputTensor(
1288 context, node, lstm::full::kInputLayerNormCoefficientsTensor)
1289 : nullptr;
1290 const TfLiteTensor* forget_layer_norm_coefficients =
1291 use_layer_norm ? GetInput(context, node,
1292 lstm::full::kForgetLayerNormCoefficientsTensor)
1293 : nullptr;
1294 const TfLiteTensor* cell_layer_norm_coefficients =
1295 use_layer_norm ? GetInput(context, node,
1296 lstm::full::kCellLayerNormCoefficientsTensor)
1297 : nullptr;
1298 const TfLiteTensor* output_layer_norm_coefficients =
1299 use_layer_norm ? GetInput(context, node,
1300 lstm::full::kOutputLayerNormCoefficientsTensor)
1301 : nullptr;
1302
1303 TfLiteTensor* output;
1304 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
1305 lstm::full::kOutputTensor, &output));
1306
// Copy out the LSTM-specific params so they can be passed to the eval functions.
1308 TfLiteLSTMParams lstm_params;
1309 lstm_params.activation = params->activation;
1310 lstm_params.cell_clip = params->cell_clip;
1311 lstm_params.proj_clip = params->proj_clip;
1312 lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;
1313
1314 switch (input_to_output_weights->type) {
1315 case kTfLiteFloat32: {
// Index the scratch buffer pointers to the global scratch buffer.
1317 TfLiteTensor* scratch_buffer;
1318 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
1319 &scratch_buffer));
1320 return lstm_eval::EvalFloat(
1321 input, input_to_input_weights, input_to_forget_weights,
1322 input_to_cell_weights, input_to_output_weights,
1323 recurrent_to_input_weights, recurrent_to_forget_weights,
1324 recurrent_to_cell_weights, recurrent_to_output_weights,
1325 cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
1326 input_layer_norm_coefficients, forget_layer_norm_coefficients,
1327 cell_layer_norm_coefficients, output_layer_norm_coefficients,
1328 /*aux_input=*/nullptr,
1329 /*aux_input_to_input_weights=*/nullptr,
1330 /*aux_input_to_forget_weights=*/nullptr,
1331 /*aux_input_to_cell_weights=*/nullptr,
1332 /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
1333 forget_gate_bias, cell_gate_bias, output_gate_bias,
1334 projection_weights, projection_bias, &lstm_params,
1335 /*forward_sequence=*/true, time_major,
1336 /*output_offset=*/0, scratch_buffer, output_state, cell_state,
1337 output);
1338 }
1339 case kTfLiteUInt8:
1340 case kTfLiteInt8: {
1341 const bool is_hybrid = input->type == kTfLiteFloat32;
1342 if (is_hybrid) {
// Index the scratch buffer pointers to the global scratch buffer.
1344 TfLiteTensor* scratch_buffer;
1345 TF_LITE_ENSURE_OK(
1346 context,
1347 GetTemporarySafe(context, node, kScratchBuffer, &scratch_buffer));
1348
1349 OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
1350 TfLiteTensor* row_sums;
1351 TF_LITE_ENSURE_OK(context,
1352 GetTemporarySafe(context, node, kRowSums, &row_sums));
1353 const int row_sums_size = row_sums->dims->data[0];
1354 return lstm_eval::EvalHybrid(
1355 input, input_to_input_weights,
1356 /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
1357 /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
1358 /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
1359 /*input_to_output_weights_ledger*/ nullptr,
1360 recurrent_to_input_weights,
1361 /*recurrent_to_input_weights_ledger*/ nullptr,
1362 recurrent_to_forget_weights,
1363 /*recurrent_to_forget_weights_ledger*/ nullptr,
1364 recurrent_to_cell_weights,
1365 /*recurrent_to_cell_weights_ledger*/ nullptr,
1366 recurrent_to_output_weights,
1367 /*recurrent_to_output_weights_ledger*/ nullptr,
1368 cell_to_input_weights, cell_to_forget_weights,
1369 cell_to_output_weights, input_layer_norm_coefficients,
1370 forget_layer_norm_coefficients, cell_layer_norm_coefficients,
1371 output_layer_norm_coefficients,
1372 /*aux_input=*/nullptr,
1373 /*aux_input_to_input_weights=*/nullptr,
1374 /*aux_input_to_forget_weights=*/nullptr,
1375 /*aux_input_to_cell_weights=*/nullptr,
1376 /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
1377 forget_gate_bias, cell_gate_bias, output_gate_bias,
1378 projection_weights, /*projection_weights_ledger*/ nullptr,
1379 projection_bias, &lstm_params,
1380 /*forward_sequence=*/true, time_major,
1381 /*output_offset=*/0, scratch_buffer,
1382 GetTemporary(context, node, kInputScalingFactors),
1383 /*aux_input_sf=*/nullptr,
1384 GetTemporary(context, node, kOutputStateScalingFactors),
1385 GetTemporary(context, node, kProductScalingFactors),
1386 GetTemporary(context, node, kRecoveredCellWeights),
1387 GetTemporary(context, node, kInputQuantized),
1388 /*aux_input_quantized=*/nullptr,
1389 GetTemporary(context, node, kOutputStateQuantized),
1390 GetTemporary(context, node, kCellStateQuantized), output_state,
1391 cell_state, GetTemporary(context, node, kAccumScratch), output,
1392 GetTemporary(context, node, kInputZeroPoints),
1393 /*aux_input_zp=*/nullptr,
1394 GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
1395 row_sums_size, &op_data->compute_row_sums,
1396 CpuBackendContext::GetFromContext(context));
1397 } else {
1398 TfLiteTensor* scratch0;
1399 TF_LITE_ENSURE_OK(context,
1400 GetTemporarySafe(context, node, 0, &scratch0));
1401 TfLiteTensor* scratch1;
1402 TF_LITE_ENSURE_OK(context,
1403 GetTemporarySafe(context, node, 1, &scratch1));
1404 TfLiteTensor* scratch2;
1405 TF_LITE_ENSURE_OK(context,
1406 GetTemporarySafe(context, node, 2, &scratch2));
1407 TfLiteTensor* scratch3;
1408 TF_LITE_ENSURE_OK(context,
1409 GetTemporarySafe(context, node, 3, &scratch3));
1410 TfLiteTensor* scratch4;
1411 TF_LITE_ENSURE_OK(context,
1412 GetTemporarySafe(context, node, 4, &scratch4));
1413 TfLiteTensor* scratch5;
1414 TF_LITE_ENSURE_OK(context,
1415 GetTemporarySafe(context, node, 5, &scratch5));
1416 return lstm_eval::EvalInteger8x8_16(
1417 input, input_to_input_weights, input_to_forget_weights,
1418 input_to_cell_weights, input_to_output_weights,
1419 recurrent_to_input_weights, recurrent_to_forget_weights,
1420 recurrent_to_cell_weights, recurrent_to_output_weights,
1421 cell_to_input_weights, cell_to_forget_weights,
1422 cell_to_output_weights, input_layer_norm_coefficients,
1423 forget_layer_norm_coefficients, cell_layer_norm_coefficients,
1424 output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
1425 cell_gate_bias, output_gate_bias, projection_weights,
1426 projection_bias, &lstm_params, /*forward_sequence=*/true,
1427 time_major, &op_data->integer_lstm_param, output_state, cell_state,
1428 output, scratch0, scratch1, scratch2, scratch3, scratch4, scratch5,
1429 CpuBackendContext::GetFromContext(context));
1430 }
1431 }
1432 default:
1433 TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
1434 TfLiteTypeGetName(input_to_output_weights->type));
1435 return kTfLiteError;
1436 }
1437 return kTfLiteOk;
1438 }
1439 } // namespace unidirectional_sequence_lstm
1440
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
1442 static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
1443 unidirectional_sequence_lstm::Free,
1444 unidirectional_sequence_lstm::Prepare,
1445 unidirectional_sequence_lstm::Eval};
1446 return &r;
1447 }
1448
1449 } // namespace builtin
1450 } // namespace ops
1451 } // namespace tflite
1452