• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"
16 
17 #include <algorithm>
18 #include <cstdio>
19 #include <vector>
20 
21 #include "tensorflow/lite/c/builtin_op_data.h"
22 #include "tensorflow/lite/core/api/error_reporter.h"
23 #include "tensorflow/lite/interpreter.h"
24 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
25 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
26 #include "tensorflow/lite/kernels/internal/tensor.h"
27 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
28 #include "tensorflow/lite/kernels/kernel_util.h"
29 #include "tensorflow/lite/kernels/lstm_shared.h"
30 #include "tensorflow/lite/kernels/op_macros.h"
31 #include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
32 
33 namespace tflite {
34 namespace optimize {
35 namespace calibration {
36 namespace builtin {
37 
38 namespace {
39 
CalculateLstmGateFloat(const float * input,const float * input_to_gate_weights,const float * aux_input,const float * aux_input_to_gate_weights,const float * output_state,const float * recurrent_to_gate_weights,const float * cell_state,const float * cell_to_gate_weights,const float * layer_norm_coefficients,const float * gate_bias,const int n_batch,const int n_input,const int n_aux_input,const int n_output,const int n_cell,const TfLiteFusedActivation activation,float * gate,const bool is_input_all_zeros,const bool is_aux_input_all_zeros,Logger * logger,int intermediate_tensor_index,ErrorReporter * error_reporter)40 inline void CalculateLstmGateFloat(
41     const float* input, const float* input_to_gate_weights,
42     const float* aux_input, const float* aux_input_to_gate_weights,
43     const float* output_state, const float* recurrent_to_gate_weights,
44     const float* cell_state, const float* cell_to_gate_weights,
45     const float* layer_norm_coefficients, const float* gate_bias,
46     const int n_batch, const int n_input, const int n_aux_input,
47     const int n_output, const int n_cell,
48     const TfLiteFusedActivation activation, float* gate,
49     const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
50     Logger* logger, int intermediate_tensor_index,
51     ErrorReporter* error_reporter) {
52   const bool use_peephole = (cell_to_gate_weights != nullptr);
53   const bool use_layer_norm = (layer_norm_coefficients != nullptr);
54 
55   // Initialize scratch buffers with bias for regular lstm or initialize with
56   // zero for layer norm lstm.
57   if (use_layer_norm) {
58     std::fill_n(gate, n_cell * n_batch, 0.0f);
59   } else {
60     tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
61   }
62   // For each batch and cell: compute input_weight * input.
63   // Skip if input is all zeros.
64   if (!is_input_all_zeros) {
65     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
66         input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
67   }
68   // For each batch and cell: compute aux_input_weight * aux_input.
69   // Skip if auxiliary input is not available or all zeros.
70   if (!is_aux_input_all_zeros) {
71     tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
72                                                       n_cell, n_aux_input,
73                                                       aux_input, n_batch, gate);
74   }
75   // For each batch and cell: compute recurrent_weight * output_state.
76   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
77       recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
78   // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
79   if (use_peephole) {
80     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
81         cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
82   }
83   // Do layer normalization (if layer norm LSTM)
84   if (use_layer_norm) {
85     logger->LogTensorValue(intermediate_tensor_index, gate, n_cell * n_batch,
86                            error_reporter);
87 
88     tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
89     tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
90                                                 gate, n_batch, gate);
91     tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
92   }
93   // Apply activation
94   tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
95                                         gate);
96 }
97 
98 // TODO(b/159066113): This is the exact same function as UpdateLstmCellFloat in
99 // kernels/lstm_eval.cc, make that public and remove this.
UpdateLstmCellFloat(int n_batch,int n_cell,float * cell_state,const float * input_gate,float * forget_gate,const float * cell_gate,bool use_cifg,float clip)100 void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
101                          const float* input_gate, float* forget_gate,
102                          const float* cell_gate, bool use_cifg, float clip) {
103   tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
104                                          n_batch * n_cell, cell_state);
105 
106   if (use_cifg) {
107     // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
108     // scratch, as input_gate array is not allocated in this case. (Be careful
109     // not to write to the scratch before reading the forget gate data.)
110     float* scratch = forget_gate;
111     tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
112     tensor_utils::VectorVectorCwiseProductAccumulate(
113         cell_gate, scratch, n_batch * n_cell, cell_state);
114   } else {
115     tensor_utils::VectorVectorCwiseProductAccumulate(
116         cell_gate, input_gate, n_batch * n_cell, cell_state);
117   }
118   if (clip > 0.0f) {
119     tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
120   }
121 }
122 
CalculateLstmOutputCalibration(int n_batch,int n_cell,int n_output,const float * cell_state,const float * output_gate,TfLiteFusedActivation activation,const float * projection_weights,const float * projection_bias,const float proj_clip,float * output_state,float * scratch,Logger * logger,int intermediate_tensor_index,ErrorReporter * error_reporter)123 void CalculateLstmOutputCalibration(
124     int n_batch, int n_cell, int n_output, const float* cell_state,
125     const float* output_gate, TfLiteFusedActivation activation,
126     const float* projection_weights, const float* projection_bias,
127     const float proj_clip, float* output_state, float* scratch, Logger* logger,
128     int intermediate_tensor_index, ErrorReporter* error_reporter) {
129   tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
130                                         activation, scratch);
131   tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
132                                          scratch);
133 
134   logger->LogTensorValue(intermediate_tensor_index, scratch, n_cell * n_batch,
135                          error_reporter);
136 
137   const bool use_projection = (projection_weights != nullptr);
138   const bool use_projection_bias = (projection_bias != nullptr);
139 
140   if (use_projection) {
141     if (use_projection_bias) {
142       tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
143                                             output_state);
144     } else {
145       std::fill_n(output_state, n_batch * n_output, 0.0f);
146     }
147     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
148         projection_weights, n_output, n_cell, scratch, n_batch, output_state);
149     if (proj_clip > 0.0f) {
150       tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
151     }
152   } else {
153     std::copy_n(scratch, n_batch * n_output, output_state);
154   }
155 }
156 
LstmStepCalibration(const float * input_ptr,const float * input_to_input_weights_ptr,const float * input_to_forget_weights_ptr,const float * input_to_cell_weights_ptr,const float * input_to_output_weights_ptr,const float * aux_input_ptr,const float * aux_input_to_input_weights_ptr,const float * aux_input_to_forget_weights_ptr,const float * aux_input_to_cell_weights_ptr,const float * aux_input_to_output_weights_ptr,const float * recurrent_to_input_weights_ptr,const float * recurrent_to_forget_weights_ptr,const float * recurrent_to_cell_weights_ptr,const float * recurrent_to_output_weights_ptr,const float * cell_to_input_weights_ptr,const float * cell_to_forget_weights_ptr,const float * cell_to_output_weights_ptr,const float * input_layer_norm_coefficients_ptr,const float * forget_layer_norm_coefficients_ptr,const float * cell_layer_norm_coefficients_ptr,const float * output_layer_norm_coefficients_ptr,const float * input_gate_bias_ptr,const float * forget_gate_bias_ptr,const float * cell_gate_bias_ptr,const float * output_gate_bias_ptr,const float * projection_weights_ptr,const float * projection_bias_ptr,const TfLiteLSTMParams * params,int n_batch,int n_cell,int n_input,int n_aux_input,int n_output,int output_batch_leading_dim,float * output_state_ptr,float * cell_state_ptr,float * scratch0,float * scratch1,float * scratch2,float * scratch3,float * output_ptr,Logger * logger,const std::vector<int> & intermediate_tensor_indexes,ErrorReporter * error_reporter)157 inline void LstmStepCalibration(
158     const float* input_ptr, const float* input_to_input_weights_ptr,
159     const float* input_to_forget_weights_ptr,
160     const float* input_to_cell_weights_ptr,
161     const float* input_to_output_weights_ptr, const float* aux_input_ptr,
162     const float* aux_input_to_input_weights_ptr,
163     const float* aux_input_to_forget_weights_ptr,
164     const float* aux_input_to_cell_weights_ptr,
165     const float* aux_input_to_output_weights_ptr,
166     const float* recurrent_to_input_weights_ptr,
167     const float* recurrent_to_forget_weights_ptr,
168     const float* recurrent_to_cell_weights_ptr,
169     const float* recurrent_to_output_weights_ptr,
170     const float* cell_to_input_weights_ptr,
171     const float* cell_to_forget_weights_ptr,
172     const float* cell_to_output_weights_ptr,
173     const float* input_layer_norm_coefficients_ptr,
174     const float* forget_layer_norm_coefficients_ptr,
175     const float* cell_layer_norm_coefficients_ptr,
176     const float* output_layer_norm_coefficients_ptr,
177     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
178     const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
179     const float* projection_weights_ptr, const float* projection_bias_ptr,
180     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
181     int n_aux_input, int n_output, int output_batch_leading_dim,
182     float* output_state_ptr, float* cell_state_ptr, float* scratch0,
183     float* scratch1, float* scratch2, float* scratch3, float* output_ptr,
184     Logger* logger, const std::vector<int>& intermediate_tensor_indexes,
185     ErrorReporter* error_reporter) {
186   ruy::profiler::ScopeLabel label("LstmStepCalibration");
187   // Since we have already checked that weights are all there or none, we can
188   // check the existence of only one to the get the condition.
189   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
190 
191   // Make named scratch buffers.
192   float* input_gate_scratch = scratch0;
193   float* forget_gate_scratch = scratch1;
194   float* cell_gate_scratch = scratch2;
195   float* output_gate_scratch = scratch3;
196 
197   // Check if inputs are all zeros so we can skip some computations.
198   const bool is_input_all_zeros =
199       tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
200   const bool is_aux_input_all_zeros =
201       (aux_input_ptr == nullptr ||
202        tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
203   if (!use_cifg) {
204     // Calculate the input gate. (If not CIFG.)
205     CalculateLstmGateFloat(
206         input_ptr, input_to_input_weights_ptr, aux_input_ptr,
207         aux_input_to_input_weights_ptr, output_state_ptr,
208         recurrent_to_input_weights_ptr, cell_state_ptr,
209         cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
210         input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
211         /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
212         is_input_all_zeros, is_aux_input_all_zeros, logger,
213         intermediate_tensor_indexes[0], error_reporter);
214   }
215   // Calculate the forget gate.
216   CalculateLstmGateFloat(
217       input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
218       aux_input_to_forget_weights_ptr, output_state_ptr,
219       recurrent_to_forget_weights_ptr, cell_state_ptr,
220       cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
221       forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
222       /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
223       is_aux_input_all_zeros, logger, intermediate_tensor_indexes[1],
224       error_reporter);
225   // Calculate the cell update gate.
226   CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
227                          aux_input_to_cell_weights_ptr, output_state_ptr,
228                          recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
229                          /*cell_to_gate_weights=*/nullptr,
230                          cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
231                          n_batch, n_input, n_aux_input, n_output, n_cell,
232                          params->activation, cell_gate_scratch,
233                          is_input_all_zeros, is_aux_input_all_zeros, logger,
234                          intermediate_tensor_indexes[2], error_reporter);
235   // Update the cell state.
236   UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
237                       forget_gate_scratch, cell_gate_scratch, use_cifg,
238                       params->cell_clip);
239   // Calculate output gate.
240   CalculateLstmGateFloat(
241       input_ptr, input_to_output_weights_ptr, aux_input_ptr,
242       aux_input_to_output_weights_ptr, output_state_ptr,
243       recurrent_to_output_weights_ptr, cell_state_ptr,
244       cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
245       output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
246       /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
247       is_aux_input_all_zeros, logger, intermediate_tensor_indexes[3],
248       error_reporter);
249   // Update the output state.
250   CalculateLstmOutputCalibration(
251       n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
252       params->activation, projection_weights_ptr, projection_bias_ptr,
253       params->proj_clip, output_state_ptr, scratch2, logger,
254       intermediate_tensor_indexes[4], error_reporter);
255   // Copy output state to the output. Note that the output's rows may not be
256   // contiguous (output_batch_leading_dim != n_output).
257   for (int b = 0; b < n_batch; b++) {
258     std::copy_n(output_state_ptr + b * n_output, n_output,
259                 output_ptr + b * output_batch_leading_dim);
260   }
261 }
262 
EvalCalibration(const TfLiteTensor * input,const TfLiteTensor * input_to_input_weights,const TfLiteTensor * input_to_forget_weights,const TfLiteTensor * input_to_cell_weights,const TfLiteTensor * input_to_output_weights,const TfLiteTensor * recurrent_to_input_weights,const TfLiteTensor * recurrent_to_forget_weights,const TfLiteTensor * recurrent_to_cell_weights,const TfLiteTensor * recurrent_to_output_weights,const TfLiteTensor * cell_to_input_weights,const TfLiteTensor * cell_to_forget_weights,const TfLiteTensor * cell_to_output_weights,const TfLiteTensor * input_layer_norm_coefficients,const TfLiteTensor * forget_layer_norm_coefficients,const TfLiteTensor * cell_layer_norm_coefficients,const TfLiteTensor * output_layer_norm_coefficients,const TfLiteTensor * aux_input,const TfLiteTensor * aux_input_to_input_weights,const TfLiteTensor * aux_input_to_forget_weights,const TfLiteTensor * aux_input_to_cell_weights,const TfLiteTensor * aux_input_to_output_weights,const TfLiteTensor * input_gate_bias,const TfLiteTensor * forget_gate_bias,const TfLiteTensor * cell_gate_bias,const TfLiteTensor * output_gate_bias,const TfLiteTensor * projection_weights,const TfLiteTensor * projection_bias,const TfLiteLSTMParams * params,bool forward_sequence,bool time_major,int output_offset,TfLiteTensor * scratch_buffer,TfLiteTensor * output_state,TfLiteTensor * cell_state,TfLiteTensor * output,Logger * logger,const std::vector<int> & intermediate_tensor_indexes,ErrorReporter * error_reporter)263 TfLiteStatus EvalCalibration(
264     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
265     const TfLiteTensor* input_to_forget_weights,
266     const TfLiteTensor* input_to_cell_weights,
267     const TfLiteTensor* input_to_output_weights,
268     const TfLiteTensor* recurrent_to_input_weights,
269     const TfLiteTensor* recurrent_to_forget_weights,
270     const TfLiteTensor* recurrent_to_cell_weights,
271     const TfLiteTensor* recurrent_to_output_weights,
272     const TfLiteTensor* cell_to_input_weights,
273     const TfLiteTensor* cell_to_forget_weights,
274     const TfLiteTensor* cell_to_output_weights,
275     const TfLiteTensor* input_layer_norm_coefficients,
276     const TfLiteTensor* forget_layer_norm_coefficients,
277     const TfLiteTensor* cell_layer_norm_coefficients,
278     const TfLiteTensor* output_layer_norm_coefficients,
279     const TfLiteTensor* aux_input,
280     const TfLiteTensor* aux_input_to_input_weights,
281     const TfLiteTensor* aux_input_to_forget_weights,
282     const TfLiteTensor* aux_input_to_cell_weights,
283     const TfLiteTensor* aux_input_to_output_weights,
284     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
285     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
286     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
287     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
288     int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
289     TfLiteTensor* cell_state, TfLiteTensor* output, Logger* logger,
290     const std::vector<int>& intermediate_tensor_indexes,
291     ErrorReporter* error_reporter) {
292   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
293   int max_time, n_batch;
294   if (input->dims->size == 3) {
295     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
296     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
297   } else {
298     max_time = 1;
299     n_batch = input->dims->data[0];
300   }
301   const int n_input = input->dims->data[input->dims->size - 1];
302   const int aux_input_size =
303       (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
304 
305   // n_cell and n_output will be the same size when there is no projection.
306   const int n_cell = input_to_output_weights->dims->data[0];
307   const int n_output = recurrent_to_output_weights->dims->data[1];
308 
309   // Since we have already checked that weights are all there or none, we can
310   // check the existence of only one to the get the condition.
311   const bool use_cifg = (input_to_input_weights == nullptr);
312 
313   // Index the scratch buffers pointers to the global scratch buffer.
314   float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
315   float* input_gate_scratch = nullptr;
316   float* cell_gate_scratch = nullptr;
317   float* forget_gate_scratch = nullptr;
318   float* output_gate_scratch = nullptr;
319   if (use_cifg) {
320     cell_gate_scratch = scratch_buffer_ptr;
321     forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
322     output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
323   } else {
324     input_gate_scratch = scratch_buffer_ptr;
325     cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
326     forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
327     output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
328   }
329 
330   const int output_batch_leading_dim =
331       output->dims->data[output->dims->size - 1];
332   if (time_major) {
333     // Loop through the sequence.
334     const int input_step = n_batch * n_input;
335     const int output_step = n_batch * output_batch_leading_dim;
336     for (int t = 0; t < max_time; t++) {
337       // If this is the forward_sequence, step forward, otherwise step
338       // backwards.
339       const int t_rel = forward_sequence ? t : max_time - t - 1;
340       const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
341       const float* aux_input_ptr = nullptr;
342       if (aux_input) {
343         aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
344       }
345       float* output_ptr_time =
346           GetTensorData<float>(output) + t_rel * output_step + output_offset;
347 
348       LstmStepCalibration(
349           input_ptr, GetTensorData<float>(input_to_input_weights),
350           GetTensorData<float>(input_to_forget_weights),
351           GetTensorData<float>(input_to_cell_weights),
352           GetTensorData<float>(input_to_output_weights), aux_input_ptr,
353           GetTensorData<float>(aux_input_to_input_weights),
354           GetTensorData<float>(aux_input_to_forget_weights),
355           GetTensorData<float>(aux_input_to_cell_weights),
356           GetTensorData<float>(aux_input_to_output_weights),
357           GetTensorData<float>(recurrent_to_input_weights),
358           GetTensorData<float>(recurrent_to_forget_weights),
359           GetTensorData<float>(recurrent_to_cell_weights),
360           GetTensorData<float>(recurrent_to_output_weights),
361           GetTensorData<float>(cell_to_input_weights),
362           GetTensorData<float>(cell_to_forget_weights),
363           GetTensorData<float>(cell_to_output_weights),
364           GetTensorData<float>(input_layer_norm_coefficients),
365           GetTensorData<float>(forget_layer_norm_coefficients),
366           GetTensorData<float>(cell_layer_norm_coefficients),
367           GetTensorData<float>(output_layer_norm_coefficients),
368           GetTensorData<float>(input_gate_bias),
369           GetTensorData<float>(forget_gate_bias),
370           GetTensorData<float>(cell_gate_bias),
371           GetTensorData<float>(output_gate_bias),
372           GetTensorData<float>(projection_weights),
373           GetTensorData<float>(projection_bias), params, n_batch, n_cell,
374           n_input, aux_input_size, n_output, output_batch_leading_dim,
375           GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
376           input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
377           output_gate_scratch, output_ptr_time, logger,
378           intermediate_tensor_indexes, error_reporter);
379     }
380   } else {
381     for (int b = 0; b < n_batch; b++) {
382       const int input_step = n_input;
383       const int output_step = output_batch_leading_dim;
384       for (int t = 0; t < max_time; t++) {
385         // If this is the forward_sequence, step forward, otherwise step
386         // backwards.
387         const int t_rel = forward_sequence ? t : max_time - t - 1;
388         const int time_offset = b * max_time + t_rel;
389         const float* input_ptr =
390             GetTensorData<float>(input) + time_offset * input_step;
391         const float* aux_input_ptr = nullptr;
392         if (aux_input) {
393           aux_input_ptr =
394               GetTensorData<float>(aux_input) + time_offset * input_step;
395         }
396         float* output_ptr = GetTensorData<float>(output) +
397                             time_offset * output_step + output_offset;
398 
399         // Offset the {output,cell}_state pointers to the right batch.
400         float* output_state_ptr =
401             GetTensorData<float>(output_state) + b * output_batch_leading_dim;
402         float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
403         // Offset the scratch pointers to the right batch.
404         float* input_gate_scratch_ptr =
405             input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
406         float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
407         float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
408         float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
409 
410         LstmStepCalibration(
411             input_ptr, GetTensorData<float>(input_to_input_weights),
412             GetTensorData<float>(input_to_forget_weights),
413             GetTensorData<float>(input_to_cell_weights),
414             GetTensorData<float>(input_to_output_weights), aux_input_ptr,
415             GetTensorData<float>(aux_input_to_input_weights),
416             GetTensorData<float>(aux_input_to_forget_weights),
417             GetTensorData<float>(aux_input_to_cell_weights),
418             GetTensorData<float>(aux_input_to_output_weights),
419             GetTensorData<float>(recurrent_to_input_weights),
420             GetTensorData<float>(recurrent_to_forget_weights),
421             GetTensorData<float>(recurrent_to_cell_weights),
422             GetTensorData<float>(recurrent_to_output_weights),
423             GetTensorData<float>(cell_to_input_weights),
424             GetTensorData<float>(cell_to_forget_weights),
425             GetTensorData<float>(cell_to_output_weights),
426             GetTensorData<float>(input_layer_norm_coefficients),
427             GetTensorData<float>(forget_layer_norm_coefficients),
428             GetTensorData<float>(cell_layer_norm_coefficients),
429             GetTensorData<float>(output_layer_norm_coefficients),
430             GetTensorData<float>(input_gate_bias),
431             GetTensorData<float>(forget_gate_bias),
432             GetTensorData<float>(cell_gate_bias),
433             GetTensorData<float>(output_gate_bias),
434             GetTensorData<float>(projection_weights),
435             GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
436             n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
437             output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
438             forget_gate_scratch_ptr, cell_gate_scratch_ptr,
439             output_gate_scratch_ptr, output_ptr, logger,
440             intermediate_tensor_indexes, error_reporter);
441       }
442     }
443   }
444   return kTfLiteOk;
445 }
446 
447 struct OpData {
448   // Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
449   // inputs).
450   // Please note the 20-input full kernel is deprecated and only kept
451   // here for backward compatibility.
452   TfLiteLSTMKernelType kernel_type;
453 
454   // If the lstm is layer norm.
455   bool use_layer_norm;
456 
457   // These fields are only used by full kernel.
458   int scratch_tensor_index;
459 };
460 
461 // Resize the output, state tensors based on the sizes of the input tensors.
462 // Allocate a temporary scratch tensor. Also check that the sizes of the input
463 // tensors match each other.
lstm_eval(TfLiteContext * context,TfLiteNode * node,LSTMType lstm_type,Logger * logger,ErrorReporter * error_reporter)464 TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
465                        LSTMType lstm_type, Logger* logger,
466                        ErrorReporter* error_reporter) {
467   const TfLiteTensor* input;
468   TF_LITE_ENSURE_OK(
469       context, GetInputSafe(context, node,
470                             ops::builtin::lstm::full::kInputTensor, &input));
471 
472   const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
473       context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
474   const TfLiteTensor* input_to_forget_weights;
475   TF_LITE_ENSURE_OK(
476       context,
477       GetInputSafe(context, node,
478                    ops::builtin::lstm::full::kInputToForgetWeightsTensor,
479                    &input_to_forget_weights));
480   const TfLiteTensor* input_to_cell_weights;
481   TF_LITE_ENSURE_OK(
482       context, GetInputSafe(context, node,
483                             ops::builtin::lstm::full::kInputToCellWeightsTensor,
484                             &input_to_cell_weights));
485   const TfLiteTensor* input_to_output_weights;
486   TF_LITE_ENSURE_OK(
487       context,
488       GetInputSafe(context, node,
489                    ops::builtin::lstm::full::kInputToOutputWeightsTensor,
490                    &input_to_output_weights));
491 
492   const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
493       context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
494   const TfLiteTensor* recurrent_to_forget_weights;
495   TF_LITE_ENSURE_OK(
496       context,
497       GetInputSafe(context, node,
498                    ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
499                    &recurrent_to_forget_weights));
500   const TfLiteTensor* recurrent_to_cell_weights;
501   TF_LITE_ENSURE_OK(
502       context,
503       GetInputSafe(context, node,
504                    ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
505                    &recurrent_to_cell_weights));
506   const TfLiteTensor* recurrent_to_output_weights;
507   TF_LITE_ENSURE_OK(
508       context,
509       GetInputSafe(context, node,
510                    ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
511                    &recurrent_to_output_weights));
512 
513   const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
514       context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
515   const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
516       context, node, ops::builtin::lstm::full::kCellToForgetWeightsTensor);
517   const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
518       context, node, ops::builtin::lstm::full::kCellToOutputWeightsTensor);
519 
520   const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
521       context, node,
522       ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor);
523   const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
524       context, node,
525       ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
526   const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
527       context, node,
528       ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor);
529   const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
530       context, node,
531       ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor);
532 
533   const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
534       context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
535   const TfLiteTensor* forget_gate_bias;
536   TF_LITE_ENSURE_OK(
537       context, GetInputSafe(context, node,
538                             ops::builtin::lstm::full::kForgetGateBiasTensor,
539                             &forget_gate_bias));
540   const TfLiteTensor* cell_gate_bias;
541   TF_LITE_ENSURE_OK(
542       context,
543       GetInputSafe(context, node, ops::builtin::lstm::full::kCellGateBiasTensor,
544                    &cell_gate_bias));
545   const TfLiteTensor* output_gate_bias;
546   TF_LITE_ENSURE_OK(
547       context, GetInputSafe(context, node,
548                             ops::builtin::lstm::full::kOutputGateBiasTensor,
549                             &output_gate_bias));
550 
551   const TfLiteTensor* projection_weights = GetOptionalInputTensor(
552       context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
553   const TfLiteTensor* projection_bias = GetOptionalInputTensor(
554       context, node, ops::builtin::lstm::full::kProjectionBiasTensor);
555 
556   // Index the scratch buffers pointers to the global scratch buffer.
557   TfLiteTensor* scratch_buffer;
558   TF_LITE_ENSURE_OK(
559       context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));
560 
561   TfLiteTensor* output_state = GetVariableInput(
562       context, node, ops::builtin::lstm::full::kOutputStateTensor);
563   TF_LITE_ENSURE(context, output_state != nullptr);
564   TfLiteTensor* cell_state = GetVariableInput(
565       context, node, ops::builtin::lstm::full::kCellStateTensor);
566   TF_LITE_ENSURE(context, cell_state != nullptr);
567 
568   TfLiteTensor* output;
569   TF_LITE_ENSURE_OK(
570       context, GetOutputSafe(context, node,
571                              ops::builtin::lstm::full::kOutputTensor, &output));
572 
573   std::vector<int> intermediate_tensor_indexes(node->intermediates->size);
574   // LSTM expect 5 intermediate tensors.
575   TF_LITE_ENSURE_EQ(context, node->intermediates->size, 5);
576   for (int i = 0; i < node->intermediates->size; ++i) {
577     intermediate_tensor_indexes[i] = node->intermediates->data[i];
578   }
579 
580   TfLiteLSTMParams lstm_params;
581   bool time_major = true;
582   switch (lstm_type) {
583     case LSTMType::kLSTM: {
584       lstm_params = *(static_cast<TfLiteLSTMParams*>(node->builtin_data));
585       time_major = true;
586       break;
587     }
588     case LSTMType::kUnidirectionalSequenceLSTM: {
589       const auto* params = static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
590           node->builtin_data);
591       // Copy out the LSTM specific params so they can be passed in the
592       // function.
593       lstm_params.activation = params->activation;
594       lstm_params.cell_clip = params->cell_clip;
595       lstm_params.proj_clip = params->proj_clip;
596       lstm_params.asymmetric_quantize_inputs =
597           params->asymmetric_quantize_inputs;
598       time_major = params->time_major;
599       break;
600     }
601     default:
602       return kTfLiteError;
603   }
604 
605   switch (input_to_output_weights->type) {
606     case kTfLiteFloat32: {
607       return EvalCalibration(
608           input, input_to_input_weights, input_to_forget_weights,
609           input_to_cell_weights, input_to_output_weights,
610           recurrent_to_input_weights, recurrent_to_forget_weights,
611           recurrent_to_cell_weights, recurrent_to_output_weights,
612           cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
613           input_layer_norm_coefficients, forget_layer_norm_coefficients,
614           cell_layer_norm_coefficients, output_layer_norm_coefficients,
615           /*aux_input=*/nullptr,
616           /*aux_input_to_input_weights=*/nullptr,
617           /*aux_input_to_forget_weights=*/nullptr,
618           /*aux_input_to_cell_weights=*/nullptr,
619           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
620           forget_gate_bias, cell_gate_bias, output_gate_bias,
621           projection_weights, projection_bias, &lstm_params,
622           /*forward_sequence=*/true,
623           /*time_major=*/time_major,
624           /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
625           logger, intermediate_tensor_indexes, error_reporter);
626     }
627     case kTfLiteUInt8:
628     case kTfLiteInt8:
629     default:
630       printf("Error. Only float model can be calibrated\n");
631       return kTfLiteError;
632   }
633   return kTfLiteOk;
634 }
635 }  // namespace
636 
lstm_logging_kernel(TfLiteContext * context,TfLiteNode * node,Logger * logger,ErrorReporter * error_reporter)637 TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
638                                  Logger* logger,
639                                  ErrorReporter* error_reporter) {
640   return lstm_eval(context, node, LSTMType::kLSTM, logger, error_reporter);
641 }
642 
unidirectional_sequence_lstm_logging_kernel(TfLiteContext * context,TfLiteNode * node,Logger * logger,ErrorReporter * error_reporter)643 TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
644     TfLiteContext* context, TfLiteNode* node, Logger* logger,
645     ErrorReporter* error_reporter) {
646   return lstm_eval(context, node, LSTMType::kUnidirectionalSequenceLSTM, logger,
647                    error_reporter);
648 }
649 
650 }  // namespace builtin
651 }  // namespace calibration
652 }  // namespace optimize
653 }  // namespace tflite
654