/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"

#include <algorithm>
#include <cstdio>
#include <vector>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"

namespace tflite {
namespace optimize {
namespace calibration {
namespace builtin {

namespace {

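// Computes the value of one LSTM gate: accumulates the input, auxiliary-input,
// recurrent, and (optional) peephole contributions into `gate` (starting from
// the gate bias, or from zero for layer-norm LSTMs), optionally applies layer
// normalization followed by the bias (logging the pre-normalization values to
// the given intermediate tensor for calibration), and finally applies the
// activation function.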
inline void CalculateLstmGateFloat(
    const float* input, const float* input_to_gate_weights,
    const float* aux_input, const float* aux_input_to_gate_weights,
    const float* output_state, const float* recurrent_to_gate_weights,
    const float* cell_state, const float* cell_to_gate_weights,
    const float* layer_norm_coefficients, const float* gate_bias,
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation, float* gate,
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    Logger* logger, int intermediate_tensor_index,
    ErrorReporter* error_reporter) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize the scratch buffer with the gate bias for a regular LSTM, or
  // with zeros for a layer-norm LSTM (the bias is added after normalization).
  if (use_layer_norm) {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  } else {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
                                                      n_cell, n_aux_input,
                                                      aux_input, n_batch, gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM).
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM).
  if (use_layer_norm) {
    logger->LogTensorValue(intermediate_tensor_index, gate, n_cell * n_batch,
                           error_reporter);

    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
                                                gate, n_batch, gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation.
  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
                                        gate);
}

// TODO(b/159066113): This is the exact same function as UpdateLstmCellFloat in
// kernels/lstm_eval.cc, make that public and remove this.
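// Updates the LSTM cell state:
//   cell_state = forget_gate .* cell_state + input_gate .* cell_gate,
// where input_gate is implicitly (1 - forget_gate) when CIFG is used. The
// result is clipped to [-clip, clip] when clip > 0.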
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
                         const float* input_gate, float* forget_gate,
                         const float* cell_gate, bool use_cifg, float clip) {
  tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
                                         n_batch * n_cell, cell_state);

  if (use_cifg) {
    // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
    // scratch, as input_gate array is not allocated in this case. (Be careful
    // not to write to the scratch before reading the forget gate data.)
    float* scratch = forget_gate;
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, scratch, n_batch * n_cell, cell_state);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, input_gate, n_batch * n_cell, cell_state);
  }
  if (clip > 0.0f) {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}

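// Computes the LSTM output:
//   output_state = output_gate .* activation(cell_state),
// optionally followed by a projection (with optional bias and clipping). The
// pre-projection value is logged to the given intermediate tensor for
// calibration.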
void CalculateLstmOutputCalibration(
    int n_batch, int n_cell, int n_output, const float* cell_state,
    const float* output_gate, TfLiteFusedActivation activation,
    const float* projection_weights, const float* projection_bias,
    const float proj_clip, float* output_state, float* scratch, Logger* logger,
    int intermediate_tensor_index, ErrorReporter* error_reporter) {
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                        activation, scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
                                         scratch);

  logger->LogTensorValue(intermediate_tensor_index, scratch, n_cell * n_batch,
                         error_reporter);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
                                            output_state);
    } else {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        projection_weights, n_output, n_cell, scratch, n_batch, output_state);
    if (proj_clip > 0.0f) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  } else {
    std::copy_n(scratch, n_batch * n_output, output_state);
  }
}

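// Performs a single LSTM time step for a batch of inputs: computes the input
// (unless CIFG), forget, cell, and output gates, updates the cell state, and
// computes the new output state, copying it to output_ptr (whose rows may have
// a leading dimension larger than n_output). Gate and output intermediates are
// logged through `logger` at the given intermediate tensor indexes.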
inline void LstmStepCalibration(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
    float* scratch1, float* scratch2, float* scratch3, float* output_ptr,
    Logger* logger, const std::vector<int>& intermediate_tensor_indexes,
    ErrorReporter* error_reporter) {
  ruy::profiler::ScopeLabel label("LstmStepCalibration");
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);

  // Make named scratch buffers.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;

  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(
        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
        aux_input_to_input_weights_ptr, output_state_ptr,
        recurrent_to_input_weights_ptr, cell_state_ptr,
        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
        is_input_all_zeros, is_aux_input_all_zeros, logger,
        intermediate_tensor_indexes[0], error_reporter);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
      aux_input_to_forget_weights_ptr, output_state_ptr,
      recurrent_to_forget_weights_ptr, cell_state_ptr,
      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[1],
      error_reporter);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
                         aux_input_to_cell_weights_ptr, output_state_ptr,
                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
                         /*cell_to_gate_weights=*/nullptr,
                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
                         n_batch, n_input, n_aux_input, n_output, n_cell,
                         params->activation, cell_gate_scratch,
                         is_input_all_zeros, is_aux_input_all_zeros, logger,
                         intermediate_tensor_indexes[2], error_reporter);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);
  // Calculate output gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
      aux_input_to_output_weights_ptr, output_state_ptr,
      recurrent_to_output_weights_ptr, cell_state_ptr,
      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[3],
      error_reporter);
  // Update the output state.
  CalculateLstmOutputCalibration(
      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
      params->activation, projection_weights_ptr, projection_bias_ptr,
      params->proj_clip, output_state_ptr, scratch2, logger,
      intermediate_tensor_indexes[4], error_reporter);
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::copy_n(output_state_ptr + b * n_output, n_output,
                output_ptr + b * output_batch_leading_dim);
  }
}

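// Runs the float LSTM over the full sequence, calling LstmStepCalibration for
// each time step so that intermediate tensor values are logged for
// calibration. Handles both time-major and batch-major input layouts; in the
// batch-major case each batch entry is processed as an independent sequence.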
TfLiteStatus EvalCalibration(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output, Logger* logger,
    const std::vector<int>& intermediate_tensor_indexes,
    ErrorReporter* error_reporter) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  int max_time, n_batch;
  if (input->dims->size == 3) {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  } else {
    max_time = 1;
    n_batch = input->dims->data[0];
  }
  const int n_input = input->dims->data[input->dims->size - 1];
  const int aux_input_size =
      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);

  // Index the scratch buffer pointers to the global scratch buffer.
  float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
  float* input_gate_scratch = nullptr;
  float* cell_gate_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
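  // The scratch buffer holds one n_batch * n_cell block per gate. With CIFG
  // there is no input gate, so only three blocks are used.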
  if (use_cifg) {
    cell_gate_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer_ptr;
    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }

  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  if (time_major) {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
      const float* aux_input_ptr = nullptr;
      if (aux_input) {
        aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
      }
      float* output_ptr_time =
          GetTensorData<float>(output) + t_rel * output_step + output_offset;

      LstmStepCalibration(
          input_ptr, GetTensorData<float>(input_to_input_weights),
          GetTensorData<float>(input_to_forget_weights),
          GetTensorData<float>(input_to_cell_weights),
          GetTensorData<float>(input_to_output_weights), aux_input_ptr,
          GetTensorData<float>(aux_input_to_input_weights),
          GetTensorData<float>(aux_input_to_forget_weights),
          GetTensorData<float>(aux_input_to_cell_weights),
          GetTensorData<float>(aux_input_to_output_weights),
          GetTensorData<float>(recurrent_to_input_weights),
          GetTensorData<float>(recurrent_to_forget_weights),
          GetTensorData<float>(recurrent_to_cell_weights),
          GetTensorData<float>(recurrent_to_output_weights),
          GetTensorData<float>(cell_to_input_weights),
          GetTensorData<float>(cell_to_forget_weights),
          GetTensorData<float>(cell_to_output_weights),
          GetTensorData<float>(input_layer_norm_coefficients),
          GetTensorData<float>(forget_layer_norm_coefficients),
          GetTensorData<float>(cell_layer_norm_coefficients),
          GetTensorData<float>(output_layer_norm_coefficients),
          GetTensorData<float>(input_gate_bias),
          GetTensorData<float>(forget_gate_bias),
          GetTensorData<float>(cell_gate_bias),
          GetTensorData<float>(output_gate_bias),
          GetTensorData<float>(projection_weights),
          GetTensorData<float>(projection_bias), params, n_batch, n_cell,
          n_input, aux_input_size, n_output, output_batch_leading_dim,
          GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
          input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
          output_gate_scratch, output_ptr_time, logger,
          intermediate_tensor_indexes, error_reporter);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float* input_ptr =
            GetTensorData<float>(input) + time_offset * input_step;
        const float* aux_input_ptr = nullptr;
        if (aux_input) {
          aux_input_ptr =
              GetTensorData<float>(aux_input) + time_offset * input_step;
        }
        float* output_ptr = GetTensorData<float>(output) +
                            time_offset * output_step + output_offset;

        // Offset the {output,cell}_state pointers to the right batch.
        float* output_state_ptr =
            GetTensorData<float>(output_state) + b * output_batch_leading_dim;
        float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepCalibration(
            input_ptr, GetTensorData<float>(input_to_input_weights),
            GetTensorData<float>(input_to_forget_weights),
            GetTensorData<float>(input_to_cell_weights),
            GetTensorData<float>(input_to_output_weights), aux_input_ptr,
            GetTensorData<float>(aux_input_to_input_weights),
            GetTensorData<float>(aux_input_to_forget_weights),
            GetTensorData<float>(aux_input_to_cell_weights),
            GetTensorData<float>(aux_input_to_output_weights),
            GetTensorData<float>(recurrent_to_input_weights),
            GetTensorData<float>(recurrent_to_forget_weights),
            GetTensorData<float>(recurrent_to_cell_weights),
            GetTensorData<float>(recurrent_to_output_weights),
            GetTensorData<float>(cell_to_input_weights),
            GetTensorData<float>(cell_to_forget_weights),
            GetTensorData<float>(cell_to_output_weights),
            GetTensorData<float>(input_layer_norm_coefficients),
            GetTensorData<float>(forget_layer_norm_coefficients),
            GetTensorData<float>(cell_layer_norm_coefficients),
            GetTensorData<float>(output_layer_norm_coefficients),
            GetTensorData<float>(input_gate_bias),
            GetTensorData<float>(forget_gate_bias),
            GetTensorData<float>(cell_gate_bias),
            GetTensorData<float>(output_gate_bias),
            GetTensorData<float>(projection_weights),
            GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
            n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
            output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_gate_scratch_ptr,
            output_gate_scratch_ptr, output_ptr, logger,
            intermediate_tensor_indexes, error_reporter);
      }
    }
  }
  return kTfLiteOk;
}

struct OpData {
  // Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
  // inputs).
  // Please note the 20-input full kernel is deprecated and only kept
  // here for backward compatibility.
  TfLiteLSTMKernelType kernel_type;

  // Whether the LSTM uses layer normalization.
  bool use_layer_norm;

  // This field is only used by the full kernel.
  int scratch_tensor_index;
};

// Gathers the LSTM input, weight, and state tensors from the node, copies out
// the LSTM parameters, and runs the calibration evaluation on the float
// kernel, logging intermediate tensor values.
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
                       LSTMType lstm_type, Logger* logger,
                       ErrorReporter* error_reporter) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kInputToCellWeightsTensor,
                            &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kForgetGateBiasTensor,
                            &forget_gate_bias));
  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, ops::builtin::lstm::full::kCellGateBiasTensor,
                   &cell_gate_bias));
  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kOutputGateBiasTensor,
                            &output_gate_bias));

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionBiasTensor);

  // Get the temporary scratch buffer tensor allocated for this node.
  TfLiteTensor* scratch_buffer;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));

  TfLiteTensor* output_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);
  TfLiteTensor* cell_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node,
                             ops::builtin::lstm::full::kOutputTensor, &output));

  std::vector<int> intermediate_tensor_indexes(node->intermediates->size);
  // LSTM expects 5 intermediate tensors.
  TF_LITE_ENSURE_EQ(context, node->intermediates->size, 5);
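  // The intermediates are logged in LstmStepCalibration: indexes 0-3 hold the
  // pre-layer-norm input, forget, cell, and output gate values (logged only
  // for layer-norm LSTMs), and index 4 holds the pre-projection output.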
  for (int i = 0; i < node->intermediates->size; ++i) {
    intermediate_tensor_indexes[i] = node->intermediates->data[i];
  }

  TfLiteLSTMParams lstm_params;
  bool time_major = true;
  switch (lstm_type) {
    case LSTMType::kLSTM: {
      lstm_params = *(static_cast<TfLiteLSTMParams*>(node->builtin_data));
      time_major = true;
      break;
    }
    case LSTMType::kUnidirectionalSequenceLSTM: {
      const auto* params = static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
      // Copy out the LSTM-specific params so they can be passed to the
      // function.
      lstm_params.activation = params->activation;
      lstm_params.cell_clip = params->cell_clip;
      lstm_params.proj_clip = params->proj_clip;
      lstm_params.asymmetric_quantize_inputs =
          params->asymmetric_quantize_inputs;
      time_major = params->time_major;
      break;
    }
    default:
      return kTfLiteError;
  }

  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return EvalCalibration(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_gate_bias, output_gate_bias,
          projection_weights, projection_bias, &lstm_params,
          /*forward_sequence=*/true,
          /*time_major=*/time_major,
          /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
          logger, intermediate_tensor_indexes, error_reporter);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8:
    default:
      printf("Error. Only float model can be calibrated\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace

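// Kernel entry points that evaluate a float LSTM (or unidirectional sequence
// LSTM) node while logging intermediate tensor values through `logger`.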
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
                                 Logger* logger,
                                 ErrorReporter* error_reporter) {
  return lstm_eval(context, node, LSTMType::kLSTM, logger, error_reporter);
}

TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
    TfLiteContext* context, TfLiteNode* node, Logger* logger,
    ErrorReporter* error_reporter) {
  return lstm_eval(context, node, LSTMType::kUnidirectionalSequenceLSTM, logger,
                   error_reporter);
}

}  // namespace builtin
}  // namespace calibration
}  // namespace optimize
}  // namespace tflite