/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h"

#include <algorithm>
#include <cstdio>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"

namespace tflite {
namespace optimize {
namespace calibration {
namespace custom {

namespace {

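// Mirrors the float LSTM step of the reference kernel, but additionally logs
// intermediate gate tensors through `logger` so the calibrator can record
// their value ranges. The entries of `intemediate_tensor_indexes` are grouped
// per gate (input, forget, cell, output, in that order): for gate g, index
// 3 * g is the pre-activation value logged just before layer normalization
// (layer-norm models only), 3 * g + 1 is the accumulator right after the
// input contribution, and 3 * g + 2 is the recurrent contribution on its own.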
inline void LstmStepWithAuxInput(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
    float* output_ptr, Logger* logger,
    const std::vector<int>& intemediate_tensor_indexes,
    ErrorReporter* error_reporter) {
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
  const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    if (!use_cifg) {
      std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
    }
    std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
    std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
    std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
  } else {
    if (!use_cifg) {
      tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
                                            n_batch, input_gate_scratch);
    }
    tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
                                          forget_gate_scratch);
    tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
                                          cell_scratch);
    tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
                                          output_gate_scratch);
  }

  // For each batch and cell: compute input_weight * input.
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_input_weights_ptr, n_cell, n_input, input_ptr, n_batch,
        input_gate_scratch);
  }

  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_forget_weights_ptr, n_cell, n_input, input_ptr, n_batch,
      forget_gate_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
                                                    n_cell, n_input, input_ptr,
                                                    n_batch, cell_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
      output_gate_scratch);

  {
    // Calibration: log each gate accumulator right after the input
    // contribution has been added.
    if (!use_cifg) {
      logger->LogTensorValue(intemediate_tensor_indexes[1], input_gate_scratch,
                             n_cell * n_batch, error_reporter);
    }
    logger->LogTensorValue(intemediate_tensor_indexes[4], forget_gate_scratch,
                           n_cell * n_batch, error_reporter);
    logger->LogTensorValue(intemediate_tensor_indexes[7], cell_scratch,
                           n_cell * n_batch, error_reporter);
    logger->LogTensorValue(intemediate_tensor_indexes[10], output_gate_scratch,
                           n_cell * n_batch, error_reporter);
  }

  // If auxiliary input is available then compute aux_input_weight * aux_input
  if (aux_input_ptr != nullptr) {
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
          n_batch, input_gate_scratch);
    }

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, forget_gate_scratch);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, cell_scratch);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, output_gate_scratch);
  }

  // For each batch and cell: compute recurrent_weight * output_state.
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, input_gate_scratch);
  }
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, forget_gate_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, cell_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, output_gate_scratch);
  {
    // Calibration: recompute the recurrent matmuls into temporaries so that
    // only the recurrent contribution to each gate is logged.
    if (!use_cifg) {
      std::vector<float> temp_input(n_batch * n_cell);
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
          n_batch, temp_input.data());
      logger->LogTensorValue(intemediate_tensor_indexes[2], temp_input.data(),
                             n_cell * n_batch, error_reporter);
    }
    std::vector<float> temp_forget(n_batch * n_cell);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, temp_forget.data());
    logger->LogTensorValue(intemediate_tensor_indexes[5], temp_forget.data(),
                           n_cell * n_batch, error_reporter);

    std::vector<float> temp_cell(n_batch * n_cell);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, temp_cell.data());

    logger->LogTensorValue(intemediate_tensor_indexes[8], temp_cell.data(),
                           n_cell * n_batch, error_reporter);

    std::vector<float> temp_output(n_batch * n_cell);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, temp_output.data());
    logger->LogTensorValue(intemediate_tensor_indexes[11], temp_output.data(),
                           n_cell * n_batch, error_reporter);
  }

  // For each batch and cell: update input gate.
  if (!use_cifg) {
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
          input_gate_scratch);
    }
    if (use_layer_norm) {
      logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
                             n_cell * n_batch, error_reporter);
      tensor_utils::MeanStddevNormalization(
          input_gate_scratch, input_gate_scratch, n_cell, n_batch);
      tensor_utils::VectorBatchVectorCwiseProduct(
          input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
          n_batch, input_gate_scratch);
      tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                         input_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                       input_gate_scratch);
  }

  // For each batch and cell: update forget gate.
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
        forget_gate_scratch);
  }
  if (use_layer_norm) {
    logger->LogTensorValue(intemediate_tensor_indexes[3], forget_gate_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                          forget_gate_scratch, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
        n_batch, forget_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                       forget_gate_scratch);
  }
  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                     forget_gate_scratch);

  // For each batch and cell: update the cell.
  tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                         n_batch * n_cell, cell_state_ptr);
  if (use_layer_norm) {
    logger->LogTensorValue(intemediate_tensor_indexes[6], cell_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
                                          n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
        cell_scratch);
    tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
                                       cell_scratch);
  }
  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                        params->activation, cell_scratch);
  if (use_cifg) {
    tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                             forget_gate_scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
  }
  if (params->cell_clip > 0.0) {
    tensor_utils::CwiseClipping(cell_state_ptr, n_batch * n_cell,
                                params->cell_clip);
  }

  // For each batch and cell: update the output gate.
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
        output_gate_scratch);
  }
  if (use_layer_norm) {
    logger->LogTensorValue(intemediate_tensor_indexes[9], output_gate_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                          output_gate_scratch, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
        n_batch, output_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                       output_gate_scratch);
  }
  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                     output_gate_scratch);
  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
                                        params->activation, cell_scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                         n_batch * n_cell, output_gate_scratch);

  const bool use_projection_weight = (projection_weights_ptr != nullptr);
  const bool use_projection_bias = (projection_bias_ptr != nullptr);

  // For each batch: update the projection and output_state. Note that since
  // the output batch rows may not be contiguous (output_batch_leading_dim !=
  // n_output), we unroll batched operations.
  if (use_projection_weight) {
    if (use_projection_bias) {
      for (int k = 0; k < n_batch; k++) {
        std::copy_n(projection_bias_ptr, n_output,
                    output_ptr + k * output_batch_leading_dim);
      }
    } else {
      for (int k = 0; k < n_batch; k++) {
        std::fill_n(output_ptr + k * output_batch_leading_dim, n_output, 0.0f);
      }
    }
    for (int k = 0; k < n_batch; k++) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          projection_weights_ptr, n_output, n_cell,
          output_gate_scratch + k * n_cell,
          /*n_batch=*/1, output_ptr + k * output_batch_leading_dim);
      if (params->proj_clip > 0.0) {
        tensor_utils::CwiseClipping(output_ptr + k * output_batch_leading_dim,
                                    n_output, params->proj_clip);
      }
    }
  } else {
    for (int k = 0; k < n_batch; k++) {
      std::copy_n(output_gate_scratch + k * n_output, n_output,
                  output_ptr + k * output_batch_leading_dim);
    }
  }
  for (int k = 0; k < n_batch; k++) {
    std::copy_n(output_ptr + k * output_batch_leading_dim, n_output,
                output_state_ptr + k * n_output);
  }
}

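// Runs the float LSTM over the whole sequence. In the time-major case a
// single LstmStepWithAuxInput call handles all batches for a given time step;
// in the batch-major case each (batch, time) pair is evaluated separately
// with n_batch = 1.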
TfLiteStatus EvalFloat(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    int output_offset, TfLiteTensor* scratch_buffer,
    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
    TfLiteTensor* output, Logger* logger,
    const std::vector<int>& intemediate_tensor_indexes,
    ErrorReporter* error_reporter) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  int max_time, n_batch;
  if (input->dims->size == 3) {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  } else {
    max_time = 1;
    n_batch = input->dims->data[0];
  }
  const int n_input = input->dims->data[input->dims->size - 1];
  const int aux_input_size =
      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);

  // Index the scratch buffers pointers to the global scratch buffer.
  float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer_ptr;
    cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }

  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  if (time_major) {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
      const float* aux_input_ptr = nullptr;
      if (aux_input) {
        aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
      }
      float* output_ptr_time =
          GetTensorData<float>(output) + t_rel * output_step + output_offset;

      LstmStepWithAuxInput(
          input_ptr, GetTensorData<float>(input_to_input_weights),
          GetTensorData<float>(input_to_forget_weights),
          GetTensorData<float>(input_to_cell_weights),
          GetTensorData<float>(input_to_output_weights), aux_input_ptr,
          GetTensorData<float>(aux_input_to_input_weights),
          GetTensorData<float>(aux_input_to_forget_weights),
          GetTensorData<float>(aux_input_to_cell_weights),
          GetTensorData<float>(aux_input_to_output_weights),
          GetTensorData<float>(recurrent_to_input_weights),
          GetTensorData<float>(recurrent_to_forget_weights),
          GetTensorData<float>(recurrent_to_cell_weights),
          GetTensorData<float>(recurrent_to_output_weights),
          GetTensorData<float>(cell_to_input_weights),
          GetTensorData<float>(cell_to_forget_weights),
          GetTensorData<float>(cell_to_output_weights),
          GetTensorData<float>(input_layer_norm_coefficients),
          GetTensorData<float>(forget_layer_norm_coefficients),
          GetTensorData<float>(cell_layer_norm_coefficients),
          GetTensorData<float>(output_layer_norm_coefficients),
          GetTensorData<float>(input_gate_bias),
          GetTensorData<float>(forget_gate_bias),
          GetTensorData<float>(cell_bias),
          GetTensorData<float>(output_gate_bias),
          GetTensorData<float>(projection_weights),
          GetTensorData<float>(projection_bias), params, n_batch, n_cell,
          n_input, aux_input_size, n_output, output_batch_leading_dim,
          GetTensorData<float>(activation_state),
          GetTensorData<float>(cell_state), input_gate_scratch,
          forget_gate_scratch, cell_scratch, output_gate_scratch,
          output_ptr_time, logger, intemediate_tensor_indexes, error_reporter);
    }
  } else {
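    // Batch-major input: evaluate each batch entry independently, one time
    // step at a time, pointing the state and scratch buffers at that entry.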
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float* input_ptr =
            GetTensorData<float>(input) + time_offset * input_step;
        const float* aux_input_ptr = nullptr;
        if (aux_input) {
          aux_input_ptr =
              GetTensorData<float>(aux_input) + time_offset * input_step;
        }
        float* output_ptr = GetTensorData<float>(output) +
                            time_offset * output_step + output_offset;

        // Offset the {activation,cell}_state pointers to the right batch.
        float* activation_state_ptr = GetTensorData<float>(activation_state) +
                                      b * output_batch_leading_dim;
        float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_scratch_ptr = cell_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepWithAuxInput(
            input_ptr, GetTensorData<float>(input_to_input_weights),
            GetTensorData<float>(input_to_forget_weights),
            GetTensorData<float>(input_to_cell_weights),
            GetTensorData<float>(input_to_output_weights), aux_input_ptr,
            GetTensorData<float>(aux_input_to_input_weights),
            GetTensorData<float>(aux_input_to_forget_weights),
            GetTensorData<float>(aux_input_to_cell_weights),
            GetTensorData<float>(aux_input_to_output_weights),
            GetTensorData<float>(recurrent_to_input_weights),
            GetTensorData<float>(recurrent_to_forget_weights),
            GetTensorData<float>(recurrent_to_cell_weights),
            GetTensorData<float>(recurrent_to_output_weights),
            GetTensorData<float>(cell_to_input_weights),
            GetTensorData<float>(cell_to_forget_weights),
            GetTensorData<float>(cell_to_output_weights),
            GetTensorData<float>(input_layer_norm_coefficients),
            GetTensorData<float>(forget_layer_norm_coefficients),
            GetTensorData<float>(cell_layer_norm_coefficients),
            GetTensorData<float>(output_layer_norm_coefficients),
            GetTensorData<float>(input_gate_bias),
            GetTensorData<float>(forget_gate_bias),
            GetTensorData<float>(cell_bias),
            GetTensorData<float>(output_gate_bias),
            GetTensorData<float>(projection_weights),
            GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
            n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
            activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
            output_ptr, logger, intemediate_tensor_indexes, error_reporter);
      }
    }
  }
  return kTfLiteOk;
}

struct OpData {
  // Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
  // inputs).
  // Please note the 20-input full kernel is deprecated and only kept
  // here for backward compatibility.
  TfLiteLSTMKernelType kernel_type;

  // Whether this is a layer-norm LSTM.
  bool use_layer_norm;

  // These fields are only used by full kernel.
  int scratch_tensor_index;
};

// Gathers the tensors used by the builtin LSTM op from `node` and runs the
// float evaluation, logging intermediate tensor values through `logger` so
// the calibrator can record their ranges.
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
                       ErrorReporter* error_reporter) {
  const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);

  const TfLiteTensor* input =
      GetInput(context, node, ops::builtin::lstm::full::kInputTensor);

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights = GetInput(
      context, node, ops::builtin::lstm::full::kInputToForgetWeightsTensor);
  const TfLiteTensor* input_to_cell_weights = GetInput(
      context, node, ops::builtin::lstm::full::kInputToCellWeightsTensor);
  const TfLiteTensor* input_to_output_weights = GetInput(
      context, node, ops::builtin::lstm::full::kInputToOutputWeightsTensor);

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights = GetInput(
      context, node, ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor);
  const TfLiteTensor* recurrent_to_cell_weights = GetInput(
      context, node, ops::builtin::lstm::full::kRecurrentToCellWeightsTensor);
  const TfLiteTensor* recurrent_to_output_weights = GetInput(
      context, node, ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor);

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, ops::builtin::lstm::full::kForgetGateBiasTensor);
  const TfLiteTensor* cell_bias =
      GetInput(context, node, ops::builtin::lstm::full::kCellGateBiasTensor);
  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, ops::builtin::lstm::full::kOutputGateBiasTensor);

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionBiasTensor);

  // Index the scratch buffers pointers to the global scratch buffer.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);

  TfLiteTensor* activation_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, activation_state != nullptr);
  TfLiteTensor* cell_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  TfLiteTensor* output =
      GetOutput(context, node, ops::builtin::lstm::full::kOutputTensor);

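  // Collect the node's intermediate tensor indexes; these are the tensors
  // whose values LstmStepWithAuxInput logs for calibration.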
  std::vector<int> intemediate_tensor_indexes(node->intermediates->size);
  for (int i = 0; i < node->intermediates->size; ++i) {
    intemediate_tensor_indexes[i] = node->intermediates->data[i];
  }

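  // Only float32 LSTMs can be calibrated; quantized weight types are
  // rejected below.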
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, params, /*forward_sequence=*/true,
          /*time_major=*/true,
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
          output, logger, intemediate_tensor_indexes, error_reporter);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8:
    default:
      printf("Error. Only float models can be calibrated.\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace

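// Entry point declared in the corresponding header; it simply forwards to
// lstm_eval above.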
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
                                 Logger* logger,
                                 ErrorReporter* error_reporter) {
  return lstm_eval(context, node, logger, error_reporter);
}

}  // namespace custom
}  // namespace calibration
}  // namespace optimize
}  // namespace tflite