/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h"

#include <algorithm>
#include <cstdio>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"

namespace tflite {
namespace optimize {
namespace calibration {
namespace custom {

namespace {

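// Runs a single float LSTM time step and logs the intermediate gate values
// through `logger` so the calibrator can collect min/max statistics for later
// quantization. The computation mirrors the reference float LSTM kernel: gate
// pre-activations are accumulated from the input, the optional auxiliary
// input, and the recurrent state, then optionally layer-normalized, passed
// through their activations, and used to update the cell state and the
// (optionally projected) output.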
inline void LstmStepWithAuxInput(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
    float* output_ptr, Logger* logger,
    const std::vector<int>& intemediate_tensor_indexes,
    ErrorReporter* error_reporter) {
  // Since we have already checked that the weights are either all present or
  // all absent, checking one of them is enough to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
  const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
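  // CIFG couples the input and forget gates (the input gate is taken as
  // 1 - forget gate), peephole adds cell-state connections to the gates, and
  // layer norm normalizes each gate pre-activation before scaling it by the
  // layer-norm coefficients and adding the gate bias.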

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    if (!use_cifg) {
      std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
    }
    std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
    std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
    std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
  } else {
    if (!use_cifg) {
      tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
                                            n_batch, input_gate_scratch);
    }
    tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
                                          forget_gate_scratch);
    tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
                                          cell_scratch);
    tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
                                          output_gate_scratch);
  }

  // For each batch and cell: compute input_weight * input.
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_input_weights_ptr, n_cell, n_input, input_ptr, n_batch,
        input_gate_scratch);
  }

  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_forget_weights_ptr, n_cell, n_input, input_ptr, n_batch,
      forget_gate_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
                                                    n_cell, n_input, input_ptr,
                                                    n_batch, cell_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
      output_gate_scratch);

  {
    // Log the gate accumulators after adding the input contributions
    // (input_weight * input), for calibration.
    if (!use_cifg) {
      logger->LogTensorValue(intemediate_tensor_indexes[1], input_gate_scratch,
                             n_cell * n_batch, error_reporter);
    }
    logger->LogTensorValue(intemediate_tensor_indexes[4], forget_gate_scratch,
                           n_cell * n_batch, error_reporter);
    logger->LogTensorValue(intemediate_tensor_indexes[7], cell_scratch,
                           n_cell * n_batch, error_reporter);
    logger->LogTensorValue(intemediate_tensor_indexes[10], output_gate_scratch,
                           n_cell * n_batch, error_reporter);
  }

  // If auxiliary input is available, compute aux_input_weight * aux_input.
  if (aux_input_ptr != nullptr) {
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
          n_batch, input_gate_scratch);
    }

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, forget_gate_scratch);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, cell_scratch);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, output_gate_scratch);
  }

  // For each batch and cell: compute recurrent_weight * output_state.
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, input_gate_scratch);
  }
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, forget_gate_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, cell_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, output_gate_scratch);
  {
    // Log the recurrent contributions (recurrent_weight * output_state) to
    // each gate for calibration. They are recomputed into temporary buffers
    // because the gate scratch buffers already hold the full accumulated sums.
    if (!use_cifg) {
      std::vector<float> temp_input(n_batch * n_cell);
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
          n_batch, temp_input.data());
      logger->LogTensorValue(intemediate_tensor_indexes[2], temp_input.data(),
                             n_cell * n_batch, error_reporter);
    }
    std::vector<float> temp_forget(n_batch * n_cell);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, temp_forget.data());
    logger->LogTensorValue(intemediate_tensor_indexes[5], temp_forget.data(),
                           n_cell * n_batch, error_reporter);

    std::vector<float> temp_cell(n_batch * n_cell);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, temp_cell.data());
    logger->LogTensorValue(intemediate_tensor_indexes[8], temp_cell.data(),
                           n_cell * n_batch, error_reporter);

    std::vector<float> temp_output(n_batch * n_cell);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, temp_output.data());
    logger->LogTensorValue(intemediate_tensor_indexes[11], temp_output.data(),
                           n_cell * n_batch, error_reporter);
  }
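  // At this point each gate scratch buffer holds the accumulated contributions
  // of the input, auxiliary input, and recurrent state (plus the gate bias
  // when layer norm is disabled). The gates are now formed by adding the
  // optional peephole terms, applying optional layer normalization, and
  // squashing with the gate activation.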

  // For each batch and cell: update input gate.
  if (!use_cifg) {
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
          input_gate_scratch);
    }
    if (use_layer_norm) {
      logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
                             n_cell * n_batch, error_reporter);
      tensor_utils::MeanStddevNormalization(
          input_gate_scratch, input_gate_scratch, n_cell, n_batch);
      tensor_utils::VectorBatchVectorCwiseProduct(
          input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
          n_batch, input_gate_scratch);
      tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                         input_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                       input_gate_scratch);
  }

  // For each batch and cell: update forget gate.
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
        forget_gate_scratch);
  }
  if (use_layer_norm) {
    logger->LogTensorValue(intemediate_tensor_indexes[3], forget_gate_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                          forget_gate_scratch, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
        n_batch, forget_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                       forget_gate_scratch);
  }
  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                     forget_gate_scratch);

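  // Cell-state update: c_t = f_t .* c_{t-1} + i_t .* g_t, where g_t is the
  // activated cell candidate held in cell_scratch (with CIFG, i_t = 1 - f_t).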
  // For each batch and cell: update the cell.
  tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                         n_batch * n_cell, cell_state_ptr);
  if (use_layer_norm) {
    logger->LogTensorValue(intemediate_tensor_indexes[6], cell_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
                                          n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
        cell_scratch);
    tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
                                       cell_scratch);
  }
  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                        params->activation, cell_scratch);
  if (use_cifg) {
    tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                             forget_gate_scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
  }
  if (params->cell_clip > 0.0) {
    tensor_utils::CwiseClipping(cell_state_ptr, n_batch * n_cell,
                                params->cell_clip);
  }

  // For each batch and cell: update the output gate.
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
        output_gate_scratch);
  }
  if (use_layer_norm) {
    logger->LogTensorValue(intemediate_tensor_indexes[9], output_gate_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                          output_gate_scratch, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
        n_batch, output_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                       output_gate_scratch);
  }
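  // Output: h_t = o_t .* activation(c_t); the activated cell state is written
  // into cell_scratch and the product overwrites output_gate_scratch.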
  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                     output_gate_scratch);
  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
                                        params->activation, cell_scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                         n_batch * n_cell, output_gate_scratch);

  const bool use_projection_weight = (projection_weights_ptr != nullptr);
  const bool use_projection_bias = (projection_bias_ptr != nullptr);

  // For each batch: update the projection and output_state. Note that since
  // the output batch rows may not be contiguous (output_batch_leading_dim !=
  // n_output), we unroll batched operations.
  if (use_projection_weight) {
    if (use_projection_bias) {
      for (int k = 0; k < n_batch; k++) {
        std::copy_n(projection_bias_ptr, n_output,
                    output_ptr + k * output_batch_leading_dim);
      }
    } else {
      for (int k = 0; k < n_batch; k++) {
        std::fill_n(output_ptr + k * output_batch_leading_dim, n_output, 0.0f);
      }
    }
    for (int k = 0; k < n_batch; k++) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          projection_weights_ptr, n_output, n_cell,
          output_gate_scratch + k * n_cell,
          /*n_batch=*/1, output_ptr + k * output_batch_leading_dim);
      if (params->proj_clip > 0.0) {
        tensor_utils::CwiseClipping(output_ptr + k * output_batch_leading_dim,
                                    n_output, params->proj_clip);
      }
    }
  } else {
    for (int k = 0; k < n_batch; k++) {
      std::copy_n(output_gate_scratch + k * n_output, n_output,
                  output_ptr + k * output_batch_leading_dim);
    }
  }
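  // Copy the (possibly strided) output rows back into output_state so the
  // next time step sees the updated recurrent state.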
  for (int k = 0; k < n_batch; k++) {
    std::copy_n(output_ptr + k * output_batch_leading_dim, n_output,
                output_state_ptr + k * n_output);
  }
}

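// Evaluates the float LSTM over the whole sequence. For time-major input the
// step function is called once per time step over the full batch; otherwise
// each batch entry is processed independently with n_batch == 1 so the
// per-batch state and scratch pointers can simply be offset.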
TfLiteStatus EvalFloat(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    int output_offset, TfLiteTensor* scratch_buffer,
    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
    TfLiteTensor* output, Logger* logger,
    const std::vector<int>& intemediate_tensor_indexes,
    ErrorReporter* error_reporter) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  int max_time, n_batch;
  if (input->dims->size == 3) {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  } else {
    max_time = 1;
    n_batch = input->dims->data[0];
  }
  const int n_input = input->dims->data[input->dims->size - 1];
  const int aux_input_size =
      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that the weights are either all present or
  // all absent, checking one of them is enough to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);

  // Index the scratch buffers pointers to the global scratch buffer.
  float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
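  // With CIFG there is no input gate, so the scratch tensor holds only three
  // n_cell * n_batch segments instead of four.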
  if (use_cifg) {
    cell_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer_ptr;
    cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }

  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
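  // Time-major input is laid out as [max_time, n_batch, n_input] and is
  // processed one whole batch per step; batch-major input is laid out as
  // [n_batch, max_time, n_input] and is processed one batch entry at a time.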
  if (time_major) {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
      const float* aux_input_ptr = nullptr;
      if (aux_input) {
        aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
      }
      float* output_ptr_time =
          GetTensorData<float>(output) + t_rel * output_step + output_offset;

      LstmStepWithAuxInput(
          input_ptr, GetTensorData<float>(input_to_input_weights),
          GetTensorData<float>(input_to_forget_weights),
          GetTensorData<float>(input_to_cell_weights),
          GetTensorData<float>(input_to_output_weights), aux_input_ptr,
          GetTensorData<float>(aux_input_to_input_weights),
          GetTensorData<float>(aux_input_to_forget_weights),
          GetTensorData<float>(aux_input_to_cell_weights),
          GetTensorData<float>(aux_input_to_output_weights),
          GetTensorData<float>(recurrent_to_input_weights),
          GetTensorData<float>(recurrent_to_forget_weights),
          GetTensorData<float>(recurrent_to_cell_weights),
          GetTensorData<float>(recurrent_to_output_weights),
          GetTensorData<float>(cell_to_input_weights),
          GetTensorData<float>(cell_to_forget_weights),
          GetTensorData<float>(cell_to_output_weights),
          GetTensorData<float>(input_layer_norm_coefficients),
          GetTensorData<float>(forget_layer_norm_coefficients),
          GetTensorData<float>(cell_layer_norm_coefficients),
          GetTensorData<float>(output_layer_norm_coefficients),
          GetTensorData<float>(input_gate_bias),
          GetTensorData<float>(forget_gate_bias),
          GetTensorData<float>(cell_bias),
          GetTensorData<float>(output_gate_bias),
          GetTensorData<float>(projection_weights),
          GetTensorData<float>(projection_bias), params, n_batch, n_cell,
          n_input, aux_input_size, n_output, output_batch_leading_dim,
          GetTensorData<float>(activation_state),
          GetTensorData<float>(cell_state), input_gate_scratch,
          forget_gate_scratch, cell_scratch, output_gate_scratch,
          output_ptr_time, logger, intemediate_tensor_indexes, error_reporter);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float* input_ptr =
            GetTensorData<float>(input) + time_offset * input_step;
        const float* aux_input_ptr = nullptr;
        if (aux_input) {
          aux_input_ptr =
              GetTensorData<float>(aux_input) + time_offset * input_step;
        }
        float* output_ptr = GetTensorData<float>(output) +
                            time_offset * output_step + output_offset;

        // Offset the {activation,cell}_state pointers to the right batch.
        float* activation_state_ptr = GetTensorData<float>(activation_state) +
                                      b * output_batch_leading_dim;
        float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_scratch_ptr = cell_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepWithAuxInput(
            input_ptr, GetTensorData<float>(input_to_input_weights),
            GetTensorData<float>(input_to_forget_weights),
            GetTensorData<float>(input_to_cell_weights),
            GetTensorData<float>(input_to_output_weights), aux_input_ptr,
            GetTensorData<float>(aux_input_to_input_weights),
            GetTensorData<float>(aux_input_to_forget_weights),
            GetTensorData<float>(aux_input_to_cell_weights),
            GetTensorData<float>(aux_input_to_output_weights),
            GetTensorData<float>(recurrent_to_input_weights),
            GetTensorData<float>(recurrent_to_forget_weights),
            GetTensorData<float>(recurrent_to_cell_weights),
            GetTensorData<float>(recurrent_to_output_weights),
            GetTensorData<float>(cell_to_input_weights),
            GetTensorData<float>(cell_to_forget_weights),
            GetTensorData<float>(cell_to_output_weights),
            GetTensorData<float>(input_layer_norm_coefficients),
            GetTensorData<float>(forget_layer_norm_coefficients),
            GetTensorData<float>(cell_layer_norm_coefficients),
            GetTensorData<float>(output_layer_norm_coefficients),
            GetTensorData<float>(input_gate_bias),
            GetTensorData<float>(forget_gate_bias),
            GetTensorData<float>(cell_bias),
            GetTensorData<float>(output_gate_bias),
            GetTensorData<float>(projection_weights),
            GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
            n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
            activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
            output_ptr, logger, intemediate_tensor_indexes, error_reporter);
      }
    }
  }
  return kTfLiteOk;
}

struct OpData {
  // Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
  // inputs).
  // Please note the 20-input full kernel is deprecated and only kept
  // here for backward compatibility.
  TfLiteLSTMKernelType kernel_type;

  // Whether the LSTM is a layer-norm LSTM.
  bool use_layer_norm;

  // These fields are only used by the full kernel.
  int scratch_tensor_index;
};

// Gathers the tensors of the full LSTM kernel from the node, collects the
// intermediate tensor indexes used for logging, and runs the float evaluation
// path. Only float models can be calibrated.
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
                       ErrorReporter* error_reporter) {
  const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);

  const TfLiteTensor* input =
      GetInput(context, node, ops::builtin::lstm::full::kInputTensor);

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights = GetInput(
      context, node, ops::builtin::lstm::full::kInputToForgetWeightsTensor);
  const TfLiteTensor* input_to_cell_weights = GetInput(
      context, node, ops::builtin::lstm::full::kInputToCellWeightsTensor);
  const TfLiteTensor* input_to_output_weights = GetInput(
      context, node, ops::builtin::lstm::full::kInputToOutputWeightsTensor);

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights = GetInput(
      context, node, ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor);
  const TfLiteTensor* recurrent_to_cell_weights = GetInput(
      context, node, ops::builtin::lstm::full::kRecurrentToCellWeightsTensor);
  const TfLiteTensor* recurrent_to_output_weights = GetInput(
      context, node, ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor);

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, ops::builtin::lstm::full::kForgetGateBiasTensor);
  const TfLiteTensor* cell_bias =
      GetInput(context, node, ops::builtin::lstm::full::kCellGateBiasTensor);
  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, ops::builtin::lstm::full::kOutputGateBiasTensor);

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionBiasTensor);

  // Get the temporary scratch buffer; the per-gate scratch pointers are
  // indexed into it inside EvalFloat.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);

  TfLiteTensor* activation_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, activation_state != nullptr);
  TfLiteTensor* cell_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  TfLiteTensor* output =
      GetOutput(context, node, ops::builtin::lstm::full::kOutputTensor);

  std::vector<int> intemediate_tensor_indexes(node->intermediates->size);
  for (int i = 0; i < node->intermediates->size; ++i) {
    intemediate_tensor_indexes[i] = node->intermediates->data[i];
  }

  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, params, /*forward_sequence=*/true,
          /*time_major=*/true,
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
          output, logger, intemediate_tensor_indexes, error_reporter);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8:
    default:
      printf("Error. Only float models can be calibrated.\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace

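// Logging kernel registered for the LSTM op during calibration; it simply
// forwards to lstm_eval, which evaluates the float LSTM while logging
// intermediate tensor values.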
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
                                 Logger* logger,
                                 ErrorReporter* error_reporter) {
  return lstm_eval(context, node, logger, error_reporter);
}

}  // namespace custom
}  // namespace calibration
}  // namespace optimize
}  // namespace tflite