1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/lstm_eval.h"
16
17 #include <math.h>
18 #include <string.h>
19
20 #include <algorithm>
21 #include <cstdint>
22 #include <memory>
23 #include <vector>
24
25 #include "ruy/profiler/instrumentation.h" // from @ruy
26 #include "tensorflow/lite/c/builtin_op_data.h"
27 #include "tensorflow/lite/c/common.h"
28 #include "tensorflow/lite/kernels/cpu_backend_context.h"
29 #include "tensorflow/lite/kernels/internal/compatibility.h"
30 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
31 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
32 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
33 #include "tensorflow/lite/kernels/op_macros.h"
34
35 namespace tflite {
36 namespace ops {
37 namespace builtin {
38 namespace lstm_eval {
39 namespace {
40
ComputeRowSums(int32_t * input_to_input_row_sums,int32_t * input_to_forget_row_sums,int32_t * input_to_cell_row_sums,int32_t * input_to_output_row_sums,int32_t * aux_input_to_input_row_sums,int32_t * aux_input_to_forget_row_sums,int32_t * aux_input_to_cell_row_sums,int32_t * aux_input_to_output_row_sums,int32_t * recurrent_to_input_row_sums,int32_t * recurrent_to_forget_row_sums,int32_t * recurrent_to_cell_row_sums,int32_t * recurrent_to_output_row_sums,int32_t * projection_weights_row_sums,int32_t * row_sums,int n_cell,int n_input,int n_aux_input,int n_output,const int8_t * input_to_input_weights_ptr,const int8_t * input_to_forget_weights_ptr,const int8_t * input_to_cell_weights_ptr,const int8_t * input_to_output_weights_ptr,const int8_t * aux_input_to_input_weights_ptr,const int8_t * aux_input_to_forget_weights_ptr,const int8_t * aux_input_to_cell_weights_ptr,const int8_t * aux_input_to_output_weights_ptr,const int8_t * recurrent_to_input_weights_ptr,const int8_t * recurrent_to_forget_weights_ptr,const int8_t * recurrent_to_cell_weights_ptr,const int8_t * recurrent_to_output_weights_ptr,const int8_t * projection_weights_ptr,bool use_cifg,const float * aux_input_ptr)41 void ComputeRowSums(
42 int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
43 int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
44 int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
45 int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
46 int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
47 int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
48 int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
49 int n_input, int n_aux_input, int n_output,
50 const int8_t* input_to_input_weights_ptr,
51 const int8_t* input_to_forget_weights_ptr,
52 const int8_t* input_to_cell_weights_ptr,
53 const int8_t* input_to_output_weights_ptr,
54 const int8_t* aux_input_to_input_weights_ptr,
55 const int8_t* aux_input_to_forget_weights_ptr,
56 const int8_t* aux_input_to_cell_weights_ptr,
57 const int8_t* aux_input_to_output_weights_ptr,
58 const int8_t* recurrent_to_input_weights_ptr,
59 const int8_t* recurrent_to_forget_weights_ptr,
60 const int8_t* recurrent_to_cell_weights_ptr,
61 const int8_t* recurrent_to_output_weights_ptr,
62 const int8_t* projection_weights_ptr, bool use_cifg,
63 const float* aux_input_ptr) {
64 // Compute the row sums for dequantization
65 if (!use_cifg) {
66 tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
67 input_to_input_row_sums, n_cell, n_input);
68 }
69 tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
70 input_to_forget_row_sums, n_cell, n_input);
71 tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
72 input_to_cell_row_sums, n_cell, n_input);
73 tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
74 input_to_output_row_sums, n_cell, n_input);
75
76 if (aux_input_ptr) {
77 if (!use_cifg) {
78 tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
79 aux_input_to_input_row_sums, n_cell,
80 n_aux_input);
81 }
82 tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
83 aux_input_to_forget_row_sums, n_cell,
84 n_aux_input);
85 tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
86 aux_input_to_cell_row_sums, n_cell,
87 n_aux_input);
88 tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
89 aux_input_to_output_row_sums, n_cell,
90 n_aux_input);
91 }
92 if (!use_cifg) {
93 tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
94 recurrent_to_input_row_sums, n_cell,
95 n_output);
96 }
97 tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
98 recurrent_to_forget_row_sums, n_cell,
99 n_output);
100 tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
101 recurrent_to_cell_row_sums, n_cell,
102 n_output);
103 tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
104 recurrent_to_output_row_sums, n_cell,
105 n_output);
106
107 if (projection_weights_ptr != nullptr) {
108 tensor_utils::ReductionSumVector(
109 projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
110 }
111 }
112
GetTensorScale(const TfLiteTensor * tensor)113 inline float GetTensorScale(const TfLiteTensor* tensor) {
114 return tensor == nullptr ? 1.0f : tensor->params.scale;
115 }
116
117 // LINT.IfChange
118 // Calculates a single LSTM gate.
119 //
120 // Implements the following formula: (* is matrix multiply)
121 // gate = activate(W_input * input + W_aux * aux_input +
122 // W_peephole * cell + W_recurrent * prev_output + bias)
123 // with layer norm:
124 // gate = activate(W_norm * normalize(...) + bias) // not adding bias inside
125 //
126 // Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
127 //
128 // Parameters:
129 // Input vectors (to LSTM): | Size: | Optional?
130 // input | n_input |
131 // aux_input | n_aux_input | y (bidir LSTM)
132 // Input vectors (persistent states):
133 // output_state | n_output |
134 // cell_state | n_cell |
135 // 'Constant' inputs:
136 // input_to_gate_weights | n_cell * n_input |
137 // aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
138 // recurrent_to_gate_weights | n_cell * n_output |
139 // cell_to_gate_weights | n_cell | y (peephole)
140 // gate_bias | n_cell |
141 // layer_norm_coefficients | n_cell | y (layer norm)
142 // Output vector:
143 // gate | n_cell |
144 // Scalar parameters:
145 // n_batch - batch size / number of vectors
146 // n_input, n_aux_input, n_output, n_cell - size of vectors.
147 // activation - activation to use.
148 // is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
149 // use_layer_norm - if doing layer norm LSTM.
CalculateLstmGateFloat(const float * input,const float * input_to_gate_weights,const float * aux_input,const float * aux_input_to_gate_weights,const float * output_state,const float * recurrent_to_gate_weights,const float * cell_state,const float * cell_to_gate_weights,const float * layer_norm_coefficients,const float * gate_bias,const int n_batch,const int n_input,const int n_aux_input,const int n_output,const int n_cell,const TfLiteFusedActivation activation,float * gate,const bool is_input_all_zeros,const bool is_aux_input_all_zeros)150 inline void CalculateLstmGateFloat(
151 const float* input, const float* input_to_gate_weights,
152 const float* aux_input, const float* aux_input_to_gate_weights,
153 const float* output_state, const float* recurrent_to_gate_weights,
154 const float* cell_state, const float* cell_to_gate_weights,
155 const float* layer_norm_coefficients, const float* gate_bias,
156 const int n_batch, const int n_input, const int n_aux_input,
157 const int n_output, const int n_cell,
158 const TfLiteFusedActivation activation, float* gate,
159 const bool is_input_all_zeros, const bool is_aux_input_all_zeros) {
160 const bool use_peephole = (cell_to_gate_weights != nullptr);
161 const bool use_layer_norm = (layer_norm_coefficients != nullptr);
162
163 // Initialize scratch buffers with bias for regular lstm or initialize with
164 // zero for layer norm lstm.
165 if (use_layer_norm) {
166 std::fill_n(gate, n_cell * n_batch, 0.0f);
167 } else {
168 tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
169 }
170 // For each batch and cell: compute input_weight * input.
171 // Skip if input is all zeros.
172 if (!is_input_all_zeros) {
173 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
174 input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
175 }
176 // For each batch and cell: compute aux_input_weight * aux_input.
177 // Skip if auxiliary input is not available or all zeros.
178 if (!is_aux_input_all_zeros) {
179 tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
180 n_cell, n_aux_input,
181 aux_input, n_batch, gate);
182 }
183 // For each batch and cell: compute recurrent_weight * output_state.
184 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
185 recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
186 // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
187 if (use_peephole) {
188 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
189 cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
190 }
191 // Do layer normalization (if layer norm LSTM)
192 if (use_layer_norm) {
193 tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
194 tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
195 gate, n_batch, gate);
196 tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
197 }
198 // Apply activation
199 tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
200 gate);
201 }
202
203 // Updates the LSTM cell state, used by both float and hybrid LSTM versions.
204 //
205 // Implements the following formula:
206 // cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
207 //
208 // With CIFG LSTM, input gate is replaced by (1-forget_gate).
209 //
210 // Parameters:
211 // - n_batch, n_cell: sizes of vectors
212 // - cell_state: input/output vector, size n_batch*n_cell
213 // - input_gate: input vector, size n_batch*n_cell.
214 // - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
215 // - cell_gate: input vector, size n_batch*n_cell.
216 // - use_cifg: use 1-forget_gate instead of input_gate.
217 // - clip: if > 0, clip the resulting cell state to [-clip, +clip].
UpdateLstmCellFloat(int n_batch,int n_cell,float * cell_state,const float * input_gate,float * forget_gate,const float * cell_gate,bool use_cifg,float clip)218 void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
219 const float* input_gate, float* forget_gate,
220 const float* cell_gate, bool use_cifg, float clip) {
221 tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
222 n_batch * n_cell, cell_state);
223
224 if (use_cifg) {
225 // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
226 // scratch, as input_gate array is not allocated in this case. (Be careful
227 // not to write to the scratch before reading the forget gate data.)
228 float* scratch = forget_gate;
229 tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
230 tensor_utils::VectorVectorCwiseProductAccumulate(
231 cell_gate, scratch, n_batch * n_cell, cell_state);
232 } else {
233 tensor_utils::VectorVectorCwiseProductAccumulate(
234 cell_gate, input_gate, n_batch * n_cell, cell_state);
235 }
236 if (clip > 0.0f) {
237 tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
238 }
239 }
240
241 // Calculates the output state tensor of an LSTM step.
242 //
243 // Implements the following formula:
244 // output_no_projection = output_gate .* activate(cell_state)
245 // (elementwise vector product)
246 // If no projection is used:
247 // output = output_state = output_no_projection
248 // With projection:
249 // output = output_state = clip(W*output_no_projection + bias)
250 //
251 // Output might not have a different 'stride' than n_batch, so we need to copy.
252 //
253 // Parameters:
254 // - n_batch: batches: the number of distinct vectors in each array.
255 // - n_cell, n_output: sizes of vectors.
256 // - cell_state, output_gate: input vectors, size n_batch*n_cell.
257 // - projection_weights, projection_weights_scale, projection_bias:
258 // constant inputs, describing projection matrix and bias.
259 // - proj_clip: if > 0, clip the output of the projection.
260 // - output_state: output vector, size n_batch*n_output. Must be contigous.
261 // - scratch: scratch area, size n_batch*n_cell.
CalculateLstmOutputFloat(int n_batch,int n_cell,int n_output,const float * cell_state,const float * output_gate,TfLiteFusedActivation activation,const float * projection_weights,const float * projection_bias,const float proj_clip,float * output_state,float * scratch)262 void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
263 const float* cell_state, const float* output_gate,
264 TfLiteFusedActivation activation,
265 const float* projection_weights,
266 const float* projection_bias,
267 const float proj_clip, float* output_state,
268 float* scratch) {
269 tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
270 activation, scratch);
271 tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
272 scratch);
273
274 const bool use_projection = (projection_weights != nullptr);
275 const bool use_projection_bias = (projection_bias != nullptr);
276
277 if (use_projection) {
278 if (use_projection_bias) {
279 tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
280 output_state);
281 } else {
282 std::fill_n(output_state, n_batch * n_output, 0.0f);
283 }
284 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
285 projection_weights, n_output, n_cell, scratch, n_batch, output_state);
286 if (proj_clip > 0.0f) {
287 tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
288 }
289 } else {
290 std::copy_n(scratch, n_batch * n_output, output_state);
291 }
292 }
293 // LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\
294 // ../experimental/kernels/fp16/lstm_eval.cc)
295
296 // Calculates a single LSTM gate, hybrid version.
297 // Implements the same functionality as CalculateLstmGateFloat.
CalculateLstmGateHybrid(const int8_t * input,const float * input_sf,const int32_t * input_zp,const int8_t * input_to_gate_weights,const uint8_t * input_to_gate_weights_ledger,const float input_to_gate_weights_scale,int32_t * input_to_gate_row_sums,const int8_t * aux_input,const float * aux_input_sf,const int32_t * aux_input_zp,const int8_t * aux_input_to_gate_weights,const float aux_input_to_gate_weights_scale,int32_t * aux_input_to_gate_row_sums,const int8_t * output_state,const float * output_state_sf,const int32_t * output_state_zp,const int8_t * recurrent_to_gate_weights,const uint8_t * recurrent_to_gate_weights_ledger,const float recurrent_to_gate_weights_scale,int32_t * recurrent_to_gate_row_sums,const float * cell_state,const int8_t * cell_to_gate_weights,const float cell_to_gate_weights_scale,const float * layer_norm_coefficients,const float * gate_bias,const int n_batch,const int n_input,const int n_aux_input,const int n_output,const int n_cell,const TfLiteFusedActivation activation,float * gate,const bool is_input_all_zeros,const bool is_aux_input_all_zeros,const bool is_output_state_all_zeros,bool * compute_row_sums,CpuBackendContext * context,float * scratch0,float * scratch1,int32_t * accum_scratch)298 void CalculateLstmGateHybrid(
299 // Input and weights
300 const int8_t* input, const float* input_sf, const int32_t* input_zp,
301 const int8_t* input_to_gate_weights,
302 const uint8_t* input_to_gate_weights_ledger,
303 const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
304 // Aux input and weights
305 const int8_t* aux_input, const float* aux_input_sf,
306 const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
307 const float aux_input_to_gate_weights_scale,
308 int32_t* aux_input_to_gate_row_sums,
309 // Output state and weights
310 const int8_t* output_state, const float* output_state_sf,
311 const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
312 const uint8_t* recurrent_to_gate_weights_ledger,
313 const float recurrent_to_gate_weights_scale,
314 int32_t* recurrent_to_gate_row_sums,
315 // Cell state and weights (peephole LSTM)
316 const float* cell_state, const int8_t* cell_to_gate_weights,
317 const float cell_to_gate_weights_scale,
318 // Layer normalization coefficients (layer norm LSTM) + gate bias
319 const float* layer_norm_coefficients, const float* gate_bias,
320 // Array sizes
321 const int n_batch, const int n_input, const int n_aux_input,
322 const int n_output, const int n_cell,
323 const TfLiteFusedActivation activation,
324 // Output
325 float* gate,
326 // Parameters for performance optimizations
327 const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
328 const bool is_output_state_all_zeros, bool* compute_row_sums,
329 CpuBackendContext* context,
330 // Scratch arrays
331 float* scratch0, // size: n_batch
332 float* scratch1, // size: n_cell, only used if peephole LSTM
333 int32_t* accum_scratch // For MatrixBatchVectorMultiplyAccumulate
334 ) {
335 const bool use_peephole = (cell_to_gate_weights != nullptr);
336 const bool use_layer_norm = (layer_norm_coefficients != nullptr);
337
338 // Initialize scratch buffers with bias for regular lstm or initialize with
339 // zero for layer norm lstm.
340 if (use_layer_norm) {
341 std::fill_n(gate, n_cell * n_batch, 0.0f);
342 } else {
343 tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
344 }
345 // For each batch and cell: compute input_weight * input.
346 // Skip if input is all zeros.
347 if (!is_input_all_zeros) {
348 if (input_to_gate_weights_ledger != nullptr) {
349 std::vector<float> scales(n_batch);
350 for (int i = 0; i < n_batch; i++) {
351 scales[i] = input_to_gate_weights_scale * input_sf[i];
352 }
353 tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
354 input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
355 input, scales.data(), n_batch, gate);
356
357 } else {
358 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
359 input_to_gate_weights, n_cell, n_input, input,
360 input_to_gate_weights_scale, input_sf, n_batch, gate,
361 /*per_channel_scale=*/nullptr, input_zp, accum_scratch,
362 input_to_gate_row_sums, compute_row_sums, scratch0, context);
363 }
364 }
365 // For each batch and cell: compute aux_input_weight * aux_input.
366 // Skip if auxiliary input is not available or all zeros.
367 if (!is_aux_input_all_zeros) {
368 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
369 aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
370 aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate,
371 /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch,
372 aux_input_to_gate_row_sums, compute_row_sums, scratch0, context);
373 }
374 // For each batch and cell: compute recurrent_weight * output_state.
375 // Skip if output state is all zeros.
376 if (!is_output_state_all_zeros) {
377 if (recurrent_to_gate_weights_ledger != nullptr) {
378 std::vector<float> scales(n_batch);
379 for (int i = 0; i < n_batch; i++) {
380 scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
381 }
382 tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
383 recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
384 n_output, output_state, scales.data(), n_batch, gate);
385 } else {
386 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
387 recurrent_to_gate_weights, n_cell, n_output, output_state,
388 recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
389 /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
390 recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
391 }
392 }
393 // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
394 if (use_peephole) {
395 float* recovered_cell_weights = scratch1;
396 tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell,
397 cell_to_gate_weights_scale,
398 recovered_cell_weights);
399 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
400 recovered_cell_weights, n_cell, cell_state, n_batch, gate);
401 }
402 // Do layer normalization (if layer norm LSTM)
403 if (use_layer_norm) {
404 tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
405 tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
406 gate, n_batch, gate);
407 tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
408 }
409 // Apply activation
410 tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch, activation,
411 gate);
412 }
413
414 // Calculates the output state tensor of an LSTM step. See Float version too.
415 //
416 // Parameters:
417 // - n_batch: batches: the number of distinct vectors in each array.
418 // - n_cell, n_output: sizes of vectors.
419 // - cell_state, output_gate: input vectors, size n_batch*n_cell.
420 // - projection_weights, projection_weights_scale, projection_bias:
421 // constant inputs, describing projection matrix and bias.
422 // - proj_clip: if > 0, clip the output of the projection.
423 // - output_state: output vector, size n_batch*n_output. Must be contigous.
424 // - asymmetric_quantize_inputs: parameter to control quantization.
425 // - projection_weights_row_sums, compute_row_sums, context: Data for optimized
426 // MatrixBatchVectorMultiplyAccumulate.
427 // - scratch0: scratch area of size n_batch*n_cell
428 // - scratch1: scratch area of size n_batch*n_cell
429 // - scratch2: scratch area of size n_batch
430 // - scratch3: scratch area of size n_batch
431 // - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
CalculateLstmOutputHybrid(int n_batch,int n_cell,int n_output,const float * cell_state,const float * output_gate,TfLiteFusedActivation activation,const int8_t * projection_weights,const uint8_t * projection_weights_ledger,float projection_weights_scale,const float * projection_bias,const float proj_clip,float * output_state,bool asymmetric_quantize_inputs,int32_t * projection_weights_row_sums,bool * compute_row_sums,CpuBackendContext * context,float * scratch0,int8_t * scratch1,float * scratch2,int32_t * scratch3,int32_t * scratch4)432 void CalculateLstmOutputHybrid(
433 int n_batch, int n_cell, int n_output, const float* cell_state,
434 const float* output_gate, TfLiteFusedActivation activation,
435 const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
436 float projection_weights_scale, const float* projection_bias,
437 const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
438 int32_t* projection_weights_row_sums, bool* compute_row_sums,
439 CpuBackendContext* context, float* scratch0, int8_t* scratch1,
440 float* scratch2, int32_t* scratch3, int32_t* scratch4) {
441 tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
442 activation, scratch0);
443 tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
444 n_batch * n_cell, scratch0);
445
446 const bool use_projection = (projection_weights != nullptr);
447 const bool use_projection_bias = (projection_bias != nullptr);
448
449 if (use_projection) {
450 if (use_projection_bias) {
451 tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
452 output_state);
453 } else {
454 std::fill_n(output_state, n_batch * n_output, 0.0f);
455 }
456 if (!tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
457 // Save quantization and matmul computation for all zero output.
458 tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
459 scratch2, scratch3,
460 asymmetric_quantize_inputs);
461 if (projection_weights_ledger != nullptr) {
462 std::vector<float> scales(n_batch);
463 for (int i = 0; i < n_batch; i++) {
464 scales[i] = projection_weights_scale * scratch2[i];
465 }
466 tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
467 projection_weights, projection_weights_ledger, n_output, n_cell,
468 scratch1, scales.data(), n_batch, output_state);
469 } else {
470 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
471 projection_weights, n_output, n_cell, scratch1,
472 projection_weights_scale, scratch2, n_batch, output_state,
473 /*per_channel_scale=*/nullptr, scratch3, scratch4,
474 projection_weights_row_sums, compute_row_sums, scratch2, context);
475 }
476 }
477 if (proj_clip > 0.0f) {
478 tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
479 }
480 } else {
481 std::copy_n(scratch0, n_batch * n_output, output_state);
482 }
483 }
484
485 // Calculates a single LSTM gate, int8x8_16 version.
486 // Implements the same functionality as CalculateLstmGateFloat.
CalculateLstmGateInteger8x8_16(const int8_t * input,const int8_t * input_to_gate_weights,const int32_t * input_to_gate_bias,const int32_t input_to_gate_scale_a,const int32_t input_to_gate_scale_b,const int8_t * output_state,const int8_t * recurrent_to_gate_weights,const int32_t * recurrent_to_gate_bias,const int32_t recurrent_to_gate_scale_a,const int32_t recurrent_to_gate_scale_b,const int16_t * cell_state,const int16_t * cell_to_gate_weights,const int32_t cell_to_gate_scale_a,const int32_t cell_to_gate_scale_b,const int16_t * layer_norm_coefficients,const int32_t * layer_norm_bias,const int32_t layer_norm_input_scale_a,const int32_t layer_norm_input_scale_b,const int32_t layer_norm_variance_guard,const int n_batch,const int n_input,const int n_output,const int n_cell,const TfLiteFusedActivation activation,int16_t * gate,CpuBackendContext * context,int32_t * scratch5)487 void CalculateLstmGateInteger8x8_16(
488 // Input and weights
489 const int8_t* input, const int8_t* input_to_gate_weights,
490 const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
491 const int32_t input_to_gate_scale_b,
492 // Output state and weights
493 const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
494 const int32_t* recurrent_to_gate_bias,
495 const int32_t recurrent_to_gate_scale_a,
496 const int32_t recurrent_to_gate_scale_b,
497 // Cell state and weights
498 const int16_t* cell_state, const int16_t* cell_to_gate_weights,
499 const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
500 // Layer normalization parameters (layer norm LSTM)
501 const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
502 const int32_t layer_norm_input_scale_a,
503 const int32_t layer_norm_input_scale_b,
504 const int32_t layer_norm_variance_guard,
505 // Array sizes
506 const int n_batch, const int n_input, const int n_output, const int n_cell,
507 const TfLiteFusedActivation activation,
508 // Output
509 int16_t* gate,
510 // Parameters for performance optimizations
511 CpuBackendContext* context,
512 // Scratch arrays
513 int32_t* scratch5) {
514 const bool use_peephole = (cell_to_gate_weights != nullptr);
515 const bool use_layer_norm = (layer_norm_coefficients != nullptr);
516
517 // Initialize scratch buffers with zeros. Note that unlike float and hybrid
518 // versions, bias is only used in layer normalization.
519 std::fill_n(gate, n_batch * n_cell, 0);
520 // For each batch and cell: compute input_weight * input.
521 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
522 input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
523 input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate,
524 context);
525 // Note: no aux_input.
526
527 // For each batch and cell: compute recurrent_weight * output_state.
528 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
529 output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
530 recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
531 n_cell, 0, scratch5, gate, context);
532 // For each batch and cell: compute cell_weight * cell_state (peephole LSTM)
533 if (use_peephole) {
534 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
535 cell_to_gate_weights, n_output, cell_state, n_batch,
536 cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
537 }
538 // Do layer normalization (if layer norm LSTM)
539 if (use_layer_norm) {
540 tensor_utils::ApplyLayerNorm(
541 gate, layer_norm_coefficients, layer_norm_bias,
542 layer_norm_input_scale_a, layer_norm_input_scale_b,
543 layer_norm_variance_guard, n_batch, n_cell, gate);
544 }
545 // Apply activation
546 switch (activation) {
547 case kTfLiteActSigmoid:
548 tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
549 break;
550 case kTfLiteActTanh:
551 tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
552 break;
553 default:
554 // Only Sigmoid or Tanh is used.
555 TFLITE_ASSERT_FALSE;
556 }
557 }
558
559 // Updates the LSTM cell state, used by both integer LSTM versions.
560 // Also see UpdateLstmCellFloat.
561 //
562 // Parameters:
563 // - n_batch, n_cell: sizes of vectors
564 // - cell_state: input/output vector, size n_batch*n_cell
565 // - cell_state_scale: scaling factor of cell state.
566 // - input_gate: input vector, size n_batch*n_cell.
567 // - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
568 // - cell_gate: input vector, size n_batch*n_cell.
569 // - use_cifg: use 1-forget_gate instead of input_gate.
570 // - clip: if > 0, clip the resulting cell state to [-clip, +clip].
UpdateLstmCellInteger(int n_batch,int n_cell,int16_t * cell_state,int32_t cell_state_scale,const int16_t * input_gate,int16_t * forget_gate,const int16_t * cell_gate,bool use_cifg,int16_t clip)571 void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
572 int32_t cell_state_scale, const int16_t* input_gate,
573 int16_t* forget_gate, const int16_t* cell_gate,
574 bool use_cifg, int16_t clip) {
575 // Use the forget_gate array as scratch, as input_gate array is not allocated
576 // in CIFG case. (Be careful not to write to the scratch before reading the
577 // forget gate data.)
578 int16_t* scratch = forget_gate;
579
580 tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
581 cell_state);
582 if (use_cifg) {
583 tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
584 tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell,
585 30 + cell_state_scale, scratch);
586 } else {
587 tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell,
588 30 + cell_state_scale, scratch);
589 }
590 tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell, cell_state);
591
592 if (clip > 0) {
593 tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
594 }
595 }
596
597 // Calculates the output state tensor of an LSTM step. See Float and hybrid
598 // versions as well.
599 //
600 // Parameters:
601 // - n_batch: batches: the number of distinct vectors in each array.
602 // - n_cell, n_output: sizes of vectors.
603 // - cell_state, output_gate: input vectors, size n_batch*n_cell.
604 // - cell_state_scale: scaling of cell_state.
605 // - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
606 // - hidden_zp: zero_point for cell_state.*output_gate
607 // - projection_weights, proj_scale_[a|b], projection_bias:
608 // constant inputs, describing projection matrix and bias.
609 // - output_state_zp: zero point of output_state. (Input, calibrated value.)
610 // - quantized_proj_clip: if > 0, clip the output of the projection.
611 // - output_state: output vector, size n_batch*n_output. Must be contigous.
612 // - context: data for optimized MatrixBatchVectorMultiplyAccumulate.
613 // - scratch0: scratch area of size n_batch*n_cell
614 // - scratch1: scratch area of size n_batch*n_cell
615 // - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
CalculateLstmOutputInteger8x8_16(int n_batch,int n_cell,int n_output,const int16_t * cell_state,int32_t cell_state_scale,const int16_t * output_gate,int32_t hidden_scale_a,int32_t hidden_scale_b,int32_t hidden_zp,const int8_t * projection_weights,int32_t proj_scale_a,int32_t proj_scale_b,const int32_t * projection_bias,int32_t output_state_zp,int8_t quantized_proj_clip,int8_t * output_state,CpuBackendContext * context,int16_t * scratch0,int8_t * scratch1,int32_t * scratch2)616 void CalculateLstmOutputInteger8x8_16(
617 int n_batch, int n_cell, int n_output, const int16_t* cell_state,
618 int32_t cell_state_scale, const int16_t* output_gate,
619 int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
620 const int8_t* projection_weights, int32_t proj_scale_a,
621 int32_t proj_scale_b, const int32_t* projection_bias,
622 int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
623 CpuBackendContext* context, int16_t* scratch0, int8_t* scratch1,
624 int32_t* scratch2) {
625 // Note: unlike float/hybrid, the activation is always Tanh.
626 tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch, n_cell,
627 scratch0);
628 tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a, hidden_scale_b,
629 n_batch, n_cell, hidden_zp, scratch1);
630
631 const bool use_projection = (projection_weights != nullptr);
632
633 if (use_projection) {
634 // Note: no bias like in float/hybrid
635 std::fill_n(output_state, n_batch * n_output, 0);
636 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
637 scratch1, projection_bias, projection_weights, proj_scale_a,
638 proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
639 output_state, context);
640 if (quantized_proj_clip > 0) {
641 tensor_utils::CwiseClipping(output_state, n_batch * n_output,
642 quantized_proj_clip);
643 }
644 } else {
645 std::copy_n(scratch1, n_batch * n_output, output_state);
646 }
647 }
648
649 // Calculates a single LSTM gate, int8x8_8 version.
650 // Implements the same functionality as CalculateLstmGateFloat.
CalculateLstmGateInteger8x8_8(const int8_t * input,int32_t input_zp,const int8_t * input_to_gate_weight,const int32_t input_to_gate_scale_a,const int32_t input_to_gate_scale_b,const int32_t input_times_weights_scale_a,const int32_t input_times_weights_scale_b,const int32_t input_times_weights_zp,const int8_t * output_state,const int32_t output_state_zp,const int8_t * recurrent_to_gate_weight,const int32_t recurrent_to_gate_scale_a,const int32_t recurrent_to_gate_scale_b,const int32_t output_state_times_weights_scale_a,const int32_t output_state_times_weights_scale_b,const int32_t output_state_times_weights_zp,const int16_t * layer_norm_gate_weight,const int32_t layer_norm_gate_scale_a,const int32_t layer_norm_gate_scale_b,const int32_t * gate_bias,const int n_batch,const int n_input,const int n_output,const int n_cell,const TfLiteFusedActivation activation,int16_t * gate,int8_t * scratch0,int8_t * scratch1)651 void CalculateLstmGateInteger8x8_8(
652 // Inputs and weights
653 const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
654 const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
655 const int32_t input_times_weights_scale_a,
656 const int32_t input_times_weights_scale_b,
657 const int32_t input_times_weights_zp,
658 // Output state and weights
659 const int8_t* output_state, const int32_t output_state_zp,
660 const int8_t* recurrent_to_gate_weight,
661 const int32_t recurrent_to_gate_scale_a,
662 const int32_t recurrent_to_gate_scale_b,
663 const int32_t output_state_times_weights_scale_a,
664 const int32_t output_state_times_weights_scale_b,
665 const int32_t output_state_times_weights_zp,
666 // Layer normalization parameters (layer norm LSTM)
667 const int16_t* layer_norm_gate_weight,
668 const int32_t layer_norm_gate_scale_a,
669 const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
670 // Array sizes
671 const int n_batch, const int n_input, const int n_output, const int n_cell,
672 const TfLiteFusedActivation activation,
673 // Output
674 int16_t* gate,
675 // Scratch arrays, both sized n_batch*n_cell
676 int8_t* scratch0, int8_t* scratch1) {
677 // Multiply input * input_weights => scratch0
678 tensor_utils::MatrixBatchVectorMultiply(
679 input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
680 input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
681 input_times_weights_zp);
682 // Multiply output_state * recurrent_weights => scratch1
683 tensor_utils::MatrixBatchVectorMultiply(
684 output_state, output_state_zp, recurrent_to_gate_weight,
685 recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
686 n_cell, scratch1, output_state_times_weights_zp);
687 // Add scratch0 + scratch1 => gate
688 tensor_utils::TwoGateSaturatingAdd(
689 scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
690 input_times_weights_scale_a, input_times_weights_scale_b,
691 output_state_times_weights_scale_a, output_state_times_weights_scale_b,
692 n_batch, n_cell, gate);
693 // Apply layer normalization.
694 tensor_utils::ApplyLayerNormFloat(
695 gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
696 layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
697 // Apply activation.
698 switch (activation) {
699 case kTfLiteActSigmoid:
700 tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate);
701 break;
702 case kTfLiteActTanh:
703 tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
704 break;
705 default:
706 // Only Sigmoid or Tanh is used.
707 TFLITE_ASSERT_FALSE;
708 }
709 }
710
711 // Calculates the output state tensor of an LSTM step. See Float and hybrid
712 // versions as well.
713 //
714 // Parameters:
715 // - n_batch: batches: the number of distinct vectors in each array.
716 // - n_cell, n_output: sizes of vectors.
717 // - cell_state, output_gate: input vectors, size n_batch*n_cell.
718 // - projection_weights, proj_scale_[a|b], projection_bias:
719 // constant inputs, describing projection matrix and bias.
720 // - output_state_zp: zero point of the output state.
721 // - quantized_proj_clip: if > 0, clip the output of the projection.
722 // - output_state: output vector, size n_batch*n_output. Must be contigous.
723 // - scratch: scratch area of size n_batch*n_cell
CalculateLstmOutputInteger8x8_8(int n_batch,int n_cell,int n_output,const int16_t * cell_state,const int16_t * output_gate,const int8_t * projection_weights,int32_t proj_scale_a,int32_t proj_scale_b,const int32_t * projection_bias,int32_t output_state_zp,int32_t quantized_proj_clip,int8_t * output_state,int16_t * scratch)724 void CalculateLstmOutputInteger8x8_8(
725 int n_batch, int n_cell, int n_output, const int16_t* cell_state,
726 const int16_t* output_gate, const int8_t* projection_weights,
727 int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
728 int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
729 int16_t* scratch) {
730 // Note: unlike float/hybrid, the activation is always Tanh.
731 tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
732 tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell, 15 + 15 - 15,
733 scratch);
734 // Note: no bias like in float/hybrid
735 tensor_utils::MatrixBatchVectorMultiply(
736 scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
737 n_batch, n_cell, n_output, output_state_zp, output_state);
738 if (quantized_proj_clip > 0) {
739 tensor_utils::CwiseClipping(output_state, n_batch * n_output,
740 quantized_proj_clip);
741 }
742 }
743
744 // Performs an LSTM batch inference step for input specified by input_ptr.
745 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
746 // biases (*_bias_ptr), and buffers (*_scratch), along with additional
747 // parameters:
748 // - params: various LSTM params including activation, clipping, etc.,
749 // - n_batch: size of batch,
750 // - n_cell: number of cells (or units),
751 // - n_input: the input size,
752 // - n_aux_input: the auxiliary input size.
753 // - n_output: the output size.
754 // - output_batch_leading_dim: the leading dimension of the output buffer.
755 //
756 // Input of size 'n_batch * n_input':
757 // input_ptr
758 // Input of size 'n_batch * n_aux_input':
759 // aux_input_ptr - optional (can be nullptr)
760 //
761 // LSTM weights:
762 // Input weights of size 'n_cell * n_input':
763 // input_to_input_weights - optional
764 // input_to_forget_weights
765 // input_to_cell_weights
766 // input_to_output_weights
767 // Auxiliary input weights of size 'n_cell * n_aux_input':
768 // aux_input_to_input_weights - optional
769 // aux_input_to_forget_weights - optional
770 // aux_input_to_cell_weights - optional
771 // aux_input_to_output_weights - optional
772 // Recurrent weights of size 'n_cell * n_output':
773 // recurrent_to_input_weights - optional
774 // recurrent_to_forget_weights
775 // recurrent_to_cell_weights
776 // recurrent_to_input_weights
777 // Peephole weights of size 'n_cell', representing diagonal matrices.
778 // cell_to_input_weights - optional
779 // cell_to_cell_weights - optional
780 // cell_to_output_weights - optional
781 // Projection weights of size 'n_output * n_cell'
782 // projection_weights_ptr - optional
783 // Gate biases of size 'n_cell':
784 // input_gate_bias_ptr - optional
785 // forget_gate_bias_ptr
786 // cell_gate_bias_ptr
787 // output_gate_bias_ptr
788 //
789 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
790 // input_layer_norm_coefficients_ptr - optional
791 // forget_layer_norm_coefficients_ptr - optional
792 // cell_layer_norm_coefficients_ptr - optional
793 // output_layer_norm_coefficients_ptr - optional
794 //
795 // The pointers to the cell and output state and the output are updated.
796 //
797 // The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
798 // in batch_major order, and each step processes batch_size many inputs from
799 // input_ptr, and updates batch_size many cell and output states.
800 //
801 // The output_batch_dim is output.shape[-1], i.e. the outermost dimension of the
802 // output tensor, and in most cases will be equal to n_output. It is usually not
803 // when we want to store the LSTM output into a slice of the output tensor, e.g.
804 // for bidirectional LSTMs with merge_outputs. In this case, the batched
805 // operations cannot be used since they assume that the batched outputs are
806 // contiguous, and we manually loop over the batched outputs.
807 // LINT.IfChange
LstmStepFloat(const float * input_ptr,const float * input_to_input_weights_ptr,const float * input_to_forget_weights_ptr,const float * input_to_cell_weights_ptr,const float * input_to_output_weights_ptr,const float * aux_input_ptr,const float * aux_input_to_input_weights_ptr,const float * aux_input_to_forget_weights_ptr,const float * aux_input_to_cell_weights_ptr,const float * aux_input_to_output_weights_ptr,const float * recurrent_to_input_weights_ptr,const float * recurrent_to_forget_weights_ptr,const float * recurrent_to_cell_weights_ptr,const float * recurrent_to_output_weights_ptr,const float * cell_to_input_weights_ptr,const float * cell_to_forget_weights_ptr,const float * cell_to_output_weights_ptr,const float * input_layer_norm_coefficients_ptr,const float * forget_layer_norm_coefficients_ptr,const float * cell_layer_norm_coefficients_ptr,const float * output_layer_norm_coefficients_ptr,const float * input_gate_bias_ptr,const float * forget_gate_bias_ptr,const float * cell_gate_bias_ptr,const float * output_gate_bias_ptr,const float * projection_weights_ptr,const float * projection_bias_ptr,const TfLiteLSTMParams * params,int n_batch,int n_cell,int n_input,int n_aux_input,int n_output,int output_batch_leading_dim,float * output_state_ptr,float * cell_state_ptr,float * scratch0,float * scratch1,float * scratch2,float * scratch3,float * output_ptr)808 inline void LstmStepFloat(
809 const float* input_ptr, const float* input_to_input_weights_ptr,
810 const float* input_to_forget_weights_ptr,
811 const float* input_to_cell_weights_ptr,
812 const float* input_to_output_weights_ptr, const float* aux_input_ptr,
813 const float* aux_input_to_input_weights_ptr,
814 const float* aux_input_to_forget_weights_ptr,
815 const float* aux_input_to_cell_weights_ptr,
816 const float* aux_input_to_output_weights_ptr,
817 const float* recurrent_to_input_weights_ptr,
818 const float* recurrent_to_forget_weights_ptr,
819 const float* recurrent_to_cell_weights_ptr,
820 const float* recurrent_to_output_weights_ptr,
821 const float* cell_to_input_weights_ptr,
822 const float* cell_to_forget_weights_ptr,
823 const float* cell_to_output_weights_ptr,
824 const float* input_layer_norm_coefficients_ptr,
825 const float* forget_layer_norm_coefficients_ptr,
826 const float* cell_layer_norm_coefficients_ptr,
827 const float* output_layer_norm_coefficients_ptr,
828 const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
829 const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
830 const float* projection_weights_ptr, const float* projection_bias_ptr,
831 const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
832 int n_aux_input, int n_output, int output_batch_leading_dim,
833 float* output_state_ptr, float* cell_state_ptr, float* scratch0,
834 float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
835 ruy::profiler::ScopeLabel label("LstmStepFloat");
836 // Since we have already checked that weights are all there or none, we can
837 // check the existence of only one to the get the condition.
838 const bool use_cifg = (input_to_input_weights_ptr == nullptr);
839
840 // Make named scratch buffers.
841 float* input_gate_scratch = scratch0;
842 float* forget_gate_scratch = scratch1;
843 float* cell_gate_scratch = scratch2;
844 float* output_gate_scratch = scratch3;
845
846 // Check if inputs are all zeros so we can skip some computations.
847 const bool is_input_all_zeros =
848 tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
849 const bool is_aux_input_all_zeros =
850 (aux_input_ptr == nullptr ||
851 tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
852 if (!use_cifg) {
853 // Calculate the input gate. (If not CIFG.)
854 CalculateLstmGateFloat(
855 input_ptr, input_to_input_weights_ptr, aux_input_ptr,
856 aux_input_to_input_weights_ptr, output_state_ptr,
857 recurrent_to_input_weights_ptr, cell_state_ptr,
858 cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
859 input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
860 /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
861 is_input_all_zeros, is_aux_input_all_zeros);
862 }
863 // Calculate the forget gate.
864 CalculateLstmGateFloat(
865 input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
866 aux_input_to_forget_weights_ptr, output_state_ptr,
867 recurrent_to_forget_weights_ptr, cell_state_ptr,
868 cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
869 forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
870 /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
871 is_aux_input_all_zeros);
872 // Calculate the cell update gate.
873 CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
874 aux_input_to_cell_weights_ptr, output_state_ptr,
875 recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
876 /*cell_to_gate_weights=*/nullptr,
877 cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
878 n_batch, n_input, n_aux_input, n_output, n_cell,
879 params->activation, cell_gate_scratch,
880 is_input_all_zeros, is_aux_input_all_zeros);
881 // Update the cell state.
882 UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
883 forget_gate_scratch, cell_gate_scratch, use_cifg,
884 params->cell_clip);
885 // Calculate output gate.
886 CalculateLstmGateFloat(
887 input_ptr, input_to_output_weights_ptr, aux_input_ptr,
888 aux_input_to_output_weights_ptr, output_state_ptr,
889 recurrent_to_output_weights_ptr, cell_state_ptr,
890 cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
891 output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
892 /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
893 is_aux_input_all_zeros);
894 // Update the output state.
895 CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
896 output_gate_scratch, params->activation,
897 projection_weights_ptr, projection_bias_ptr,
898 params->proj_clip, output_state_ptr, scratch2);
899 // Copy output state to the output. Note that the output's rows may not be
900 // contiguous (output_batch_leading_dim != n_output).
901 for (int b = 0; b < n_batch; b++) {
902 std::copy_n(output_state_ptr + b * n_output, n_output,
903 output_ptr + b * output_batch_leading_dim);
904 }
905 }
906 // LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\
907 // ../experimental/kernels/fp16/lstm_eval.cc)
908
909 // Same as above but with quantized weight matrices. In detail:
910 // Input of size 'n_batch * n_input':
911 // input_ptr
912 // Input of size 'n_batch * n_aux_input':
913 // aux_input_ptr - optional (can be nullptr)
914 //
915 // LSTM weights:
916 // Quantized input weights of size 'n_cell * n_input':
917 // input_to_input_weights - optional
918 // input_to_forget_weights
919 // input_to_cell_weights
920 // input_to_input_weights
921 // Quantized auxiliary input weights of size 'n_cell * n_aux_input':
922 // aux_input_to_input_weights - optional
923 // aux_input_to_forget_weights - optional
924 // aux_input_to_cell_weights - optional
925 // aux_input_to_output_weights - optional
926 // Quantized recurrent weights of size 'n_cell * n_output':
927 // recurrent_to_input_weights - optional
928 // recurrent_to_forget_weights
929 // recurrent_to_cell_weights
930 // recurrent_to_input_weights
931 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
932 // cell_to_input_weights - optional
933 // cell_to_cell_weights - optional
934 // cell_to_output_weights - optional
935 // Quantized projection weights of size 'n_output * n_cell'
936 // projection_weights_ptr - optional
937 // Weight scales (scalars) for each of the weights above.
938 // input_to_input_weights_scale - optional
939 // input_to_forget_weights_scale
940 // input_to_cell_weights_scale
941 // input_to_output_weights_scale
942 // aux_input_to_input_weights_scale - optional
943 // aux_input_to_forget_weights_scale - optional
944 // aux_input_to_cell_weights_scale - optional
945 // aux_input_to_output_weights_scale - optional
946 // recurrent_to_input_weights_scale - optional
947 // recurrent_to_forget_weights_scale
948 // recurrent_to_cell_weights_scale
949 // recurrent_to_output_weights_scale
950 // cell_to_input_weights_scale,
951 // cell_to_forget_weights_scale,
952 // cell_to_output_weights_scale,
953 // projection_weights_scale - optional
954 // Gate biases of size 'n_cell':
955 // input_gate_bias_ptr - optional
956 // forget_gate_bias_ptr
957 // cell_gate_bias_ptr
958 // output_gate_bias_ptr
959 //
960 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
961 // input_layer_norm_coefficients_ptr - optional
962 // forget_layer_norm_coefficients_ptr - optional
963 // cell_layer_norm_coefficients_ptr - optional
964 // output_layer_norm_coefficients_ptr - optional
965 //
966 // Temporary pre-allocated storage for quantized values:
967 // quantized_input_ptr (same size as input_ptr)
968 // quantized_output_state_ptr (same size as output_state_ptr)
969 // quantized_output_scratch (same size as cell_state_ptr)
970 // Temporary pre-allocated storage for recovered values:
971 // recovered_cell_weights (same size as cell_to_*_weights)
972 //
973 // Outputs:
974 // output_state_ptr - size 'n_batch * n_output'
975 // cell_state_ptr - size 'n_batch * n_cell'
976 // output_ptr - size 'n_batch * output_batch_leading_dim'
LstmStepHybrid(const float * input_ptr,const int8_t * input_to_input_weights_ptr,const uint8_t * input_to_input_weights_ledger_ptr,float input_to_input_weights_scale,const int8_t * input_to_forget_weights_ptr,const uint8_t * input_to_forget_weights_ledger_ptr,float input_to_forget_weights_scale,const int8_t * input_to_cell_weights_ptr,const uint8_t * input_to_cell_weights_ledger_ptr,float input_to_cell_weights_scale,const int8_t * input_to_output_weights_ptr,const uint8_t * input_to_output_weights_ledger_ptr,float input_to_output_weights_scale,const float * aux_input_ptr,const int8_t * aux_input_to_input_weights_ptr,float aux_input_to_input_weights_scale,const int8_t * aux_input_to_forget_weights_ptr,float aux_input_to_forget_weights_scale,const int8_t * aux_input_to_cell_weights_ptr,float aux_input_to_cell_weights_scale,const int8_t * aux_input_to_output_weights_ptr,float aux_input_to_output_weights_scale,const int8_t * recurrent_to_input_weights_ptr,const uint8_t * recurrent_to_input_weights_ledger_ptr,float recurrent_to_input_weights_scale,const int8_t * recurrent_to_forget_weights_ptr,const uint8_t * recurrent_to_forget_weights_ledger_ptr,float recurrent_to_forget_weights_scale,const int8_t * recurrent_to_cell_weights_ptr,const uint8_t * recurrent_to_cell_weights_ledger_ptr,float recurrent_to_cell_weights_scale,const int8_t * recurrent_to_output_weights_ptr,const uint8_t * recurrent_to_output_weights_ledger_ptr,float recurrent_to_output_weights_scale,const int8_t * cell_to_input_weights_ptr,float cell_to_input_weights_scale,const int8_t * cell_to_forget_weights_ptr,float cell_to_forget_weights_scale,const int8_t * cell_to_output_weights_ptr,float cell_to_output_weights_scale,const float * input_layer_norm_coefficients_ptr,const float * forget_layer_norm_coefficients_ptr,const float * cell_layer_norm_coefficients_ptr,const float * output_layer_norm_coefficients_ptr,const float * input_gate_bias_ptr,const float * forget_gate_bias_ptr,const float * cell_gate_bias_ptr,const float * output_gate_bias_ptr,const int8_t * projection_weights_ptr,const uint8_t * projection_weights_ledger_ptr,float projection_weights_scale,const float * projection_bias_ptr,const TfLiteLSTMParams * params,int n_batch,int n_cell,int n_input,int n_aux_input,int n_output,int output_batch_leading_dim,float * scratch0,float * scratch1,float * scratch2,float * scratch3,float * input_sf,float * aux_input_sf,float * output_state_sf,float * scaling_factors_scratch,float * recovered_cell_weights,int8_t * quantized_input_ptr,int8_t * quantized_aux_input_ptr,int8_t * quantized_output_state_ptr,int8_t * quantized_output_scratch,float * output_state_ptr,float * cell_state_ptr,int32_t * accum_scratch_ptr,float * output_ptr,int32_t * input_zp,int32_t * aux_input_zp,int32_t * output_state_zp,int32_t * row_sums,int row_sums_size,bool * compute_row_sums,bool asymmetric_quantize_inputs,CpuBackendContext * context)977 inline void LstmStepHybrid(
978 const float* input_ptr, const int8_t* input_to_input_weights_ptr,
979 const uint8_t* input_to_input_weights_ledger_ptr,
980 float input_to_input_weights_scale,
981 const int8_t* input_to_forget_weights_ptr,
982 const uint8_t* input_to_forget_weights_ledger_ptr,
983 float input_to_forget_weights_scale,
984 const int8_t* input_to_cell_weights_ptr,
985 const uint8_t* input_to_cell_weights_ledger_ptr,
986 float input_to_cell_weights_scale,
987 const int8_t* input_to_output_weights_ptr,
988 const uint8_t* input_to_output_weights_ledger_ptr,
989 float input_to_output_weights_scale, const float* aux_input_ptr,
990 const int8_t* aux_input_to_input_weights_ptr,
991 float aux_input_to_input_weights_scale,
992 const int8_t* aux_input_to_forget_weights_ptr,
993 float aux_input_to_forget_weights_scale,
994 const int8_t* aux_input_to_cell_weights_ptr,
995 float aux_input_to_cell_weights_scale,
996 const int8_t* aux_input_to_output_weights_ptr,
997 float aux_input_to_output_weights_scale,
998 const int8_t* recurrent_to_input_weights_ptr,
999 const uint8_t* recurrent_to_input_weights_ledger_ptr,
1000 float recurrent_to_input_weights_scale,
1001 const int8_t* recurrent_to_forget_weights_ptr,
1002 const uint8_t* recurrent_to_forget_weights_ledger_ptr,
1003 float recurrent_to_forget_weights_scale,
1004 const int8_t* recurrent_to_cell_weights_ptr,
1005 const uint8_t* recurrent_to_cell_weights_ledger_ptr,
1006 float recurrent_to_cell_weights_scale,
1007 const int8_t* recurrent_to_output_weights_ptr,
1008 const uint8_t* recurrent_to_output_weights_ledger_ptr,
1009 float recurrent_to_output_weights_scale,
1010 const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
1011 const int8_t* cell_to_forget_weights_ptr,
1012 float cell_to_forget_weights_scale,
1013 const int8_t* cell_to_output_weights_ptr,
1014 float cell_to_output_weights_scale,
1015 const float* input_layer_norm_coefficients_ptr,
1016 const float* forget_layer_norm_coefficients_ptr,
1017 const float* cell_layer_norm_coefficients_ptr,
1018 const float* output_layer_norm_coefficients_ptr,
1019 const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
1020 const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
1021 const int8_t* projection_weights_ptr,
1022 const uint8_t* projection_weights_ledger_ptr,
1023 float projection_weights_scale, const float* projection_bias_ptr,
1024 const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
1025 int n_aux_input, int n_output, int output_batch_leading_dim,
1026 float* scratch0, float* scratch1, float* scratch2, float* scratch3,
1027 float* input_sf, float* aux_input_sf, float* output_state_sf,
1028 float* scaling_factors_scratch, float* recovered_cell_weights,
1029 int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
1030 int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
1031 float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
1032 float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
1033 int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
1034 bool* compute_row_sums, bool asymmetric_quantize_inputs,
1035 CpuBackendContext* context) {
1036 ruy::profiler::ScopeLabel label("LstmStepHybrid");
1037 // Since we have already checked that weights are all there or none, we
1038 // can check the existence of only one to get the condition.
1039 const bool use_cifg = (input_to_input_weights_ptr == nullptr);
1040 // Make named scratch buffers for the different gates.
1041 float* input_gate_scratch = scratch0;
1042 float* forget_gate_scratch = scratch1;
1043 float* cell_gate_scratch = scratch2;
1044 float* output_gate_scratch = scratch3;
1045
1046 int32_t* input_to_input_row_sums = nullptr;
1047 int32_t* input_to_forget_row_sums = nullptr;
1048 int32_t* input_to_cell_row_sums = nullptr;
1049 int32_t* input_to_output_row_sums = nullptr;
1050 int32_t* aux_input_to_input_row_sums = nullptr;
1051 int32_t* aux_input_to_forget_row_sums = nullptr;
1052 int32_t* aux_input_to_cell_row_sums = nullptr;
1053 int32_t* aux_input_to_output_row_sums = nullptr;
1054 int32_t* recurrent_to_input_row_sums = nullptr;
1055 int32_t* recurrent_to_forget_row_sums = nullptr;
1056 int32_t* recurrent_to_cell_row_sums = nullptr;
1057 int32_t* recurrent_to_output_row_sums = nullptr;
1058 int32_t* projection_weights_row_sums = nullptr;
1059
1060 if (asymmetric_quantize_inputs) {
1061 int num_row_sums = use_cifg ? 6 : 8;
1062 if (aux_input_ptr != nullptr) {
1063 num_row_sums += use_cifg ? 3 : 4;
1064 }
1065 if (projection_weights_ptr != nullptr) {
1066 num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
1067 }
1068 TF_LITE_ASSERT(row_sums_size == num_row_sums);
1069 input_to_input_row_sums = row_sums;
1070 input_to_forget_row_sums =
1071 use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
1072 input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
1073 input_to_output_row_sums = input_to_cell_row_sums + n_cell;
1074 if (aux_input_ptr != nullptr) {
1075 aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
1076 aux_input_to_forget_row_sums = use_cifg
1077 ? aux_input_to_input_row_sums
1078 : aux_input_to_input_row_sums + n_cell;
1079 aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
1080 aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
1081 }
1082 recurrent_to_input_row_sums = aux_input_ptr
1083 ? aux_input_to_output_row_sums + n_cell
1084 : input_to_output_row_sums + n_cell;
1085 recurrent_to_forget_row_sums = use_cifg
1086 ? recurrent_to_input_row_sums
1087 : recurrent_to_input_row_sums + n_cell;
1088 recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
1089 recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
1090 if (projection_weights_ptr != nullptr) {
1091 projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
1092 }
1093 if (*compute_row_sums) {
1094 ComputeRowSums(
1095 input_to_input_row_sums, input_to_forget_row_sums,
1096 input_to_cell_row_sums, input_to_output_row_sums,
1097 aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
1098 aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
1099 recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
1100 recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
1101 projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
1102 n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
1103 input_to_cell_weights_ptr, input_to_output_weights_ptr,
1104 aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
1105 aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
1106 recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
1107 recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
1108 projection_weights_ptr, use_cifg, aux_input_ptr);
1109 *compute_row_sums = false;
1110 }
1111 }
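  // Worked example of the row_sums layout above (illustration only): with
  // use_cifg == false, an aux input present and projection weights present,
  // row_sums is partitioned into consecutive n_cell-sized segments in the
  // order
  //   [input_to_{input,forget,cell,output} |
  //    aux_input_to_{input,forget,cell,output} |
  //    recurrent_to_{input,forget,cell,output} | projection ...],
  // i.e. 4 + 4 + 4 segments plus ceil(n_output / n_cell) segments for the
  // projection weights, matching the num_row_sums check above.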
1112
1113 // Check if inputs are all zeros so we can skip some computations.
1114 const bool is_input_all_zeros =
1115 tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
1116 const bool is_aux_input_all_zeros =
1117 (aux_input_ptr == nullptr ||
1118 tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
1119 const bool is_output_state_all_zeros =
1120 tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output);
1121 // Quantize inputs.
1122 if (!is_input_all_zeros) {
1123 tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input,
1124 quantized_input_ptr, input_sf, input_zp,
1125 asymmetric_quantize_inputs);
1126 }
1127 if (!is_aux_input_all_zeros) {
1128 tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input,
1129 quantized_aux_input_ptr, aux_input_sf,
1130 aux_input_zp, asymmetric_quantize_inputs);
1131 }
1132 if (!is_output_state_all_zeros) {
1133 tensor_utils::BatchQuantizeFloats(
1134 output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
1135 output_state_sf, output_state_zp, asymmetric_quantize_inputs);
1136 }
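  // Rough sketch of the quantization above (an assumption about the helper,
  // shown for illustration only): for each batch row, roughly
  //   scale = max_abs(row) / 127;  q[i] = round(row[i] / scale);  // int8
  // with an additional per-row zero point when asymmetric_quantize_inputs is
  // true. The per-row scales (input_sf, aux_input_sf, output_state_sf) are
  // later combined with the weight scales inside CalculateLstmGateHybrid.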
1137 if (!use_cifg) {
1138 // Calculate the input gate. (If not CIFG.)
1139 CalculateLstmGateHybrid(
1140 quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
1141 input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
1142 input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
1143 aux_input_zp, aux_input_to_input_weights_ptr,
1144 aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
1145 quantized_output_state_ptr, output_state_sf, output_state_zp,
1146 recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
1147 recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
1148 cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
1149 input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
1150 n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1151 input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1152 is_output_state_all_zeros, compute_row_sums, context,
1153 scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1154 }
1155 // Calculate the forget gate.
1156 CalculateLstmGateHybrid(
1157 quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
1158 input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
1159 input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
1160 aux_input_zp, aux_input_to_forget_weights_ptr,
1161 aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
1162 quantized_output_state_ptr, output_state_sf, output_state_zp,
1163 recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
1164 recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
1165 cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
1166 forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
1167 n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1168 forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1169 is_output_state_all_zeros, compute_row_sums, context,
1170 scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1171 // Calculate the cell update gate.
1172 CalculateLstmGateHybrid(
1173 quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
1174 input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
1175 input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
1176 aux_input_zp, aux_input_to_cell_weights_ptr,
1177 aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
1178 quantized_output_state_ptr, output_state_sf, output_state_zp,
1179 recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
1180 recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
1181 /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
1182 /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
1183 cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
1184 params->activation, cell_gate_scratch, is_input_all_zeros,
1185 is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums,
1186 context, scaling_factors_scratch, recovered_cell_weights,
1187 accum_scratch_ptr);
1188 // Update the cell state.
1189 UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
1190 forget_gate_scratch, cell_gate_scratch, use_cifg,
1191 params->cell_clip);
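  // In effect: cell_state = forget_gate .* cell_state + input_gate .* cell_gate
  // (with input_gate taken as 1 - forget_gate when CIFG is used), followed by
  // clipping to +/- cell_clip when cell_clip > 0.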
1192 // Calculate the output gate.
1193 CalculateLstmGateHybrid(
1194 quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
1195 input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
1196 input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
1197 aux_input_zp, aux_input_to_output_weights_ptr,
1198 aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
1199 quantized_output_state_ptr, output_state_sf, output_state_zp,
1200 recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
1201 recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
1202 cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
1203 output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
1204 n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1205 output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1206 is_output_state_all_zeros, compute_row_sums, context,
1207 scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1208 // Update the output state.
1209 CalculateLstmOutputHybrid(
1210 n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
1211 params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
1212 projection_weights_scale, projection_bias_ptr, params->proj_clip,
1213 output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
1214 compute_row_sums, context, scratch2, quantized_output_scratch, input_sf,
1215 input_zp, accum_scratch_ptr);
1216 // Copy output state to the output. Note that the output's rows may not be
1217 // contiguous (output_batch_leading_dim != n_output).
1218 for (int b = 0; b < n_batch; b++) {
1219 std::copy_n(output_state_ptr + b * n_output, n_output,
1220 output_ptr + b * output_batch_leading_dim);
1221 }
1222 }
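
// Illustrative sketch of what each CalculateLstmGateHybrid call above
// computes (hypothetical names, peephole and layer norm omitted): for every
// batch b and cell c, the int8 matmul accumulators are rescaled back to
// float and passed through the gate activation, roughly
//
//   acc = input_sf[b] * input_to_gate_scale * dot(W_x_q[c], x_q[b]) +
//         output_state_sf[b] * recurrent_to_gate_scale * dot(W_h_q[c], h_q[b]) +
//         gate_bias[c];
//   gate[b][c] = activation(acc);
//
// The row sums computed earlier supply the zero-point correction for these
// products when asymmetric_quantize_inputs is true.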
1223
1224 // Fully quantized LSTM kernel for 16-bit gate matmul output.
1225 //
1226 // Input tensor of size n_batch * n_input:
1227 // input_ptr
1228 //
1229 // LSTM weights:
1230 // Quantized input weights of size 'n_cell * n_input':
1231 // input_to_input_weight_ptr - optional
1232 // input_to_forget_weight_ptr - optional
1233 // input_to_cell_weight_ptr - optional
1234 // input_to_output_weight_ptr - optional
1235 //
1236 // Quantized recurrent weights of size 'n_cell * n_output':
1237 // recurrent_to_input_weight_ptr - optional
1238 // recurrent_to_forget_weights_ptr
1239 // recurrent_to_cell_weights_ptr
1240 //   recurrent_to_output_weights_ptr
1241 //
1242 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
1243 // cell_to_input_weights - optional
1244 //   cell_to_forget_weights - optional
1245 // cell_to_output_weights - optional
1246 //
1247 // Quantized projection weights of size 'n_output * n_cell'
1248 // projection_weight_ptr - optional
1249 //
1250 // Weight scales (scalars) for each of the weights above.
1251 // effective_input_to_input_scale_a - optional
1252 // effective_input_to_input_scale_b - optional
1253 // effective_input_to_forget_scale_a
1254 // effective_input_to_forget_scale_b
1255 // effective_input_to_cell_scale_a
1256 // effective_input_to_cell_scale_b
1257 // effective_input_to_output_scale_a
1258 // effective_input_to_output_scale_b
1259 // effective_recurrent_to_input_scale_a - optional
1260 // effective_recurrent_to_input_scale_b - optional
1261 // effective_recurrent_to_forget_scale_a
1262 // effective_recurrent_to_forget_scale_b
1263 // effective_recurrent_to_cell_scale_a
1264 // effective_recurrent_to_cell_scale_b
1265 // effective_recurrent_to_output_scale_a
1266 // effective_recurrent_to_output_scale_b
1267 // effective_proj_scale_a - optional
1268 // effective_proj_scale_b - optional
1269 //
1270 // Gate biases of size 'n_cell':
1271 // input_gate_bias_ptr - optional
1272 // forget_gate_bias_ptr
1273 // cell_gate_bias_ptr
1274 // output_gate_bias_ptr
1275 //
1276 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1277 // layer_norm_input_weight_ptr - optional
1278 // layer_norm_forget_weight_ptr - optional
1279 // layer_norm_cell_weight_ptr - optional
1280 // layer_norm_output_weight_ptr - optional
1281 //
1282 // Layer norm scales of size 'n_cell'.
1283 // layer_norm_input_scale_a - optional
1284 // layer_norm_input_scale_b - optional
1285 // layer_norm_forget_scale_a - optional
1286 // layer_norm_forget_scale_b - optional
1287 // layer_norm_cell_scale_a - optional
1288 // layer_norm_cell_scale_b - optional
1289 // layer_norm_output_scale_a - optional
1290 // layer_norm_output_scale_b - optional
1291 //
1292 // Scalar values:
1293 // quantized_cell_clip: quantized clip value for cell.
1294 // quantized_proj_clip: quantized clip value for projection.
1295 // cell_state_scale: the power of two scale for cell state.
1296 //
1297 // Zero points:
1298 // output_state_zp: zero point of output state
1299 // hidden_zp: zero point for hidden state.
1300 //
1301 // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
1302 // n_batch.
1303 // scratch0
1304 // scratch1
1305 // scratch2
1306 // scratch3
1307 // scratch4
1308 // scratch5: this scratch buffer is created purely for optimizing the
1309 // MatrixBatchVectorMultiplyAccumulate.
1310 //
1311 // Outputs:
1312 // output_state_ptr - size 'n_batch * n_output'
1313 // cell_state_ptr - size 'n_batch * n_cell'
1314 // output_ptr - size 'n_batch * n_output'
1315 // TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
1316 inline void LstmStepInteger8x8_16(
1317 const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
1318 int32_t effective_input_to_input_scale_a,
1319 int32_t effective_input_to_input_scale_b,
1320 const int8_t* input_to_forget_weight_ptr,
1321 int32_t effective_input_to_forget_scale_a,
1322 int32_t effective_input_to_forget_scale_b,
1323 const int8_t* input_to_cell_weight_ptr,
1324 int32_t effective_input_to_cell_scale_a,
1325 int32_t effective_input_to_cell_scale_b,
1326 const int8_t* input_to_output_weight_ptr,
1327 int32_t effective_input_to_output_scale_a,
1328 int32_t effective_input_to_output_scale_b,
1329 const int8_t* recurrent_to_input_weight_ptr,
1330 int32_t effective_recurrent_to_input_scale_a,
1331 int32_t effective_recurrent_to_input_scale_b,
1332 const int8_t* recurrent_to_forget_weight_ptr,
1333 int32_t effective_recurrent_to_forget_scale_a,
1334 int32_t effective_recurrent_to_forget_scale_b,
1335 const int8_t* recurrent_to_cell_weight_ptr,
1336 int32_t effective_recurrent_to_cell_scale_a,
1337 int32_t effective_recurrent_to_cell_scale_b,
1338 const int8_t* recurrent_to_output_weight_ptr,
1339 int32_t effective_recurrent_to_output_scale_a,
1340 int32_t effective_recurrent_to_output_scale_b,
1341 const int16_t* cell_to_input_weight_ptr,
1342 int32_t effective_cell_to_input_scale_a,
1343 int32_t effective_cell_to_input_scale_b,
1344 const int16_t* cell_to_forget_weight_ptr,
1345 int32_t effective_cell_to_forget_scale_a,
1346 int32_t effective_cell_to_forget_scale_b,
1347 const int16_t* cell_to_output_weight_ptr,
1348 int32_t effective_cell_to_output_scale_a,
1349 int32_t effective_cell_to_output_scale_b,
1350 const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
1351 int32_t effective_proj_scale_b, int32_t hidden_zp,
1352 int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
1353 const int16_t* layer_norm_input_weight_ptr,
1354 int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
1355 const int16_t* layer_norm_forget_weight_ptr,
1356 int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
1357 const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
1358 int32_t layer_norm_cell_scale_b,
1359 const int16_t* layer_norm_output_weight_ptr,
1360 int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
1361 const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
1362 const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
1363 int16_t quantized_cell_clip, int8_t quantized_proj_clip,
1364 int32_t cell_state_scale, int32_t input_variance_guard,
1365 int32_t forget_variance_guard, int32_t cell_variance_guard,
1366 int32_t output_variance_guard,
1367 const int32_t* input_to_forget_effective_bias,
1368 const int32_t* recurrent_to_forget_effective_bias,
1369 const int32_t* input_to_cell_effective_bias,
1370 const int32_t* recurrent_to_cell_effective_bias,
1371 const int32_t* input_to_output_effective_bias,
1372 const int32_t* recurrent_to_output_effective_bias,
1373 const int32_t* input_to_input_effective_bias,
1374 const int32_t* recurrent_to_input_effective_bias,
1375 const int32_t* projection_effective_bias, int n_batch, int n_cell,
1376 int n_input, int n_output, int8_t* output_state_ptr,
1377 int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
1378 int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
1379 int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
1380 ruy::profiler::ScopeLabel label("LstmStepInteger8x8_16");
1381 // Make named scratch buffers for the different gates.
1382 int16_t* input_gate_scratch = scratch0;
1383 int16_t* forget_gate_scratch = scratch1;
1384 int16_t* cell_gate_scratch = scratch2;
1385 int16_t* output_gate_scratch = scratch3;
1386
1387 // Since we have already checked that weights are all there or none, we
1388 // can check the existence of only one to get the condition.
1389 const bool use_cifg = (input_to_input_weight_ptr == nullptr);
1390
1391 // Check for nullptrs.
1392 TFLITE_DCHECK(input_to_forget_effective_bias);
1393 TFLITE_DCHECK(recurrent_to_forget_effective_bias);
1394 TFLITE_DCHECK(input_to_cell_effective_bias);
1395 TFLITE_DCHECK(recurrent_to_cell_effective_bias);
1396 TFLITE_DCHECK(input_to_output_effective_bias);
1397 TFLITE_DCHECK(recurrent_to_output_effective_bias);
1398 if (!use_cifg) {
1399 TFLITE_DCHECK(input_to_input_effective_bias);
1400 TFLITE_DCHECK(recurrent_to_input_effective_bias);
1401 }
1402 const bool use_projection = (projection_weight_ptr != nullptr);
1403 if (use_projection) {
1404 TFLITE_DCHECK(projection_effective_bias);
1405 }
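  // Note (an assumption for the reader, not verified here): the
  // *_effective_bias tensors checked above are typically precomputed in
  // Prepare() by folding the relevant input / output-state zero-point
  // contributions into the original gate biases, which is why the integer
  // gate calculations below need no explicit zero-point correction terms.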
1406 if (!use_cifg) {
1407 // Calculate the input gate. (If not CIFG.)
1408 CalculateLstmGateInteger8x8_16(
1409 input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
1410 effective_input_to_input_scale_a, effective_input_to_input_scale_b,
1411 output_state_ptr, recurrent_to_input_weight_ptr,
1412 recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
1413 effective_recurrent_to_input_scale_b, cell_state_ptr,
1414 cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
1415 effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
1416 input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
1417 input_variance_guard, n_batch, n_input, n_output, n_cell,
1418 kTfLiteActSigmoid, input_gate_scratch, context, scratch5);
1419 }
1420 // Calculate the forget gate.
1421 CalculateLstmGateInteger8x8_16(
1422 input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
1423 effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
1424 output_state_ptr, recurrent_to_forget_weight_ptr,
1425 recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
1426 effective_recurrent_to_forget_scale_b, cell_state_ptr,
1427 cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
1428 effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
1429 forget_gate_bias_ptr, layer_norm_forget_scale_a,
1430 layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
1431 n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, context,
1432 scratch5);
1433 // Calculate the cell update gate.
1434 CalculateLstmGateInteger8x8_16(
1435 input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
1436 effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
1437 output_state_ptr, recurrent_to_cell_weight_ptr,
1438 recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
1439 effective_recurrent_to_cell_scale_b, cell_state_ptr,
1440 /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
1441 /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
1442 cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
1443 cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
1444 cell_gate_scratch, context, scratch5);
1445 // Update the cell state.
1446 UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
1447 input_gate_scratch, forget_gate_scratch,
1448 cell_gate_scratch, use_cifg, quantized_cell_clip);
1449 // Calculate the output gate.
1450 CalculateLstmGateInteger8x8_16(
1451 input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
1452 effective_input_to_output_scale_a, effective_input_to_output_scale_b,
1453 output_state_ptr, recurrent_to_output_weight_ptr,
1454 recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
1455 effective_recurrent_to_output_scale_b, cell_state_ptr,
1456 cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
1457 effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
1458 output_gate_bias_ptr, layer_norm_output_scale_a,
1459 layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
1460 n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, context,
1461 scratch5);
1462 // Update the output state.
1463 CalculateLstmOutputInteger8x8_16(
1464 n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
1465 output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
1466 hidden_zp, projection_weight_ptr, effective_proj_scale_a,
1467 effective_proj_scale_b, projection_effective_bias, output_state_zp,
1468 quantized_proj_clip, output_state_ptr, context, scratch0, scratch4,
1469 scratch5);
1470 // Copy output state to the output. Note that unlike float or hybrid, output
1471 // is always contiguous.
1472 std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
1473 }
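
// Illustrative note on the effective scales used above (the exponent
// convention is an assumption for illustration, not taken from this file):
// each effective_*_scale_a / effective_*_scale_b pair encodes a real-valued
// rescaling factor as a fixed-point multiplier plus a power-of-two shift,
// roughly
//   real_scale ~= scale_a * 2^(scale_b - 31)
// which is how quantized-multiplier rescaling is commonly applied in TFLite
// integer kernels.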
1474
1475 // Fully quantized LSTM kernel for 8-bit gate matmul output.
1476 //
1477 // Input tensor of size n_batch * n_input:
1478 // input_ptr
1479 //
1480 // LSTM weights:
1481 // Quantized input weights of size 'n_cell * n_input':
1482 // input_to_input_weight_ptr - optional
1483 // input_to_forget_weight_ptr - optional
1484 // input_to_cell_weight_ptr - optional
1485 // input_to_output_weight_ptr - optional
1486 //
1487 // Quantized recurrent weights of size 'n_cell * n_output':
1488 // recurrent_to_input_weight_ptr - optional
1489 // recurrent_to_forget_weights_ptr
1490 // recurrent_to_cell_weights_ptr
1491 //   recurrent_to_output_weights_ptr
1492 //
1493 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
1494 // cell_to_input_weights - optional
1495 //   cell_to_forget_weights - optional
1496 // cell_to_output_weights - optional
1497 //
1498 // Quantized projection weights of size 'n_output * n_cell'
1499 // projection_weight_ptr - optional
1500 //
1501 // Weight scales (scalars) for each of the weights above.
1502 // effective_input_to_input_scale_a - optional
1503 // effective_input_to_input_scale_b - optional
1504 // effective_input_to_forget_scale_a
1505 // effective_input_to_forget_scale_b
1506 // effective_input_to_cell_scale_a
1507 // effective_input_to_cell_scale_b
1508 // effective_input_to_output_scale_a
1509 // effective_input_to_output_scale_b
1510 // effective_recurrent_to_input_scale_a - optional
1511 // effective_recurrent_to_input_scale_b - optional
1512 // effective_recurrent_to_forget_scale_a
1513 // effective_recurrent_to_forget_scale_b
1514 // effective_recurrent_to_cell_scale_a
1515 // effective_recurrent_to_cell_scale_b
1516 // effective_recurrent_to_output_scale_a
1517 // effective_recurrent_to_output_scale_b
1518 // effective_proj_scale_a - optional
1519 // effective_proj_scale_b - optional
1520 //
1521 // Gate biases of size 'n_cell':
1522 // input_gate_bias_ptr - optional
1523 // forget_gate_bias_ptr
1524 // cell_gate_bias_ptr
1525 // output_gate_bias_ptr
1526 //
1527 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1528 // layer_norm_input_weight_ptr - optional
1529 // layer_norm_forget_weight_ptr - optional
1530 // layer_norm_cell_weight_ptr - optional
1531 // layer_norm_output_weight_ptr - optional
1532 //
1533 // Layer norm scales of size 'n_cell'.
1534 // layer_norm_input_scale_a - optional
1535 // layer_norm_input_scale_b - optional
1536 // layer_norm_forget_scale_a - optional
1537 // layer_norm_forget_scale_b - optional
1538 // layer_norm_cell_scale_a - optional
1539 // layer_norm_cell_scale_b - optional
1540 // layer_norm_output_scale_a - optional
1541 // layer_norm_output_scale_b - optional
1542 //
1543 // Scalar values:
1544 // quantized_cell_clip: quantized clip value for cell.
1545 // quantized_proj_clip: quantized clip value for projection.
1546 // cell_state_scale: the power of two scale for cell state.
1547 //
1548 // Zero points:
1549 // output_state_zp: zero point of output state.
1550 // hidden_zp: zero point for hidden state.
1551 //
1552 // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
1553 // n_batch.
1554 // scratch0
1555 // scratch1
1556 // scratch2
1557 // scratch3
1558 // scratch4
1559 // scratch5
1560 // scratch6
1561 // scratch7
1562 //
1563 // Outputs:
1564 // output_state_ptr - size 'n_batch * n_output'
1565 // cell_state_ptr - size 'n_batch * n_cell'
1566 // output_ptr - size 'n_batch * n_output'
1567 // TODO(b/148688698): Move zero point calculation into Prepare().
1568 // TODO(b/159947023): scratch5 is unused, remove.
1569 inline void LstmStepInteger8x8_8(
1570 const int8_t* input_ptr, int32_t input_zp,
1571 const int8_t* input_to_input_weight_ptr,
1572 int32_t effective_input_to_input_scale_a,
1573 int32_t effective_input_to_input_scale_b,
1574 const int8_t* input_to_forget_weight_ptr,
1575 int32_t effective_input_to_forget_scale_a,
1576 int32_t effective_input_to_forget_scale_b,
1577 const int8_t* input_to_cell_weight_ptr,
1578 int32_t effective_input_to_cell_scale_a,
1579 int32_t effective_input_to_cell_scale_b,
1580 const int8_t* input_to_output_weight_ptr,
1581 int32_t effective_input_to_output_scale_a,
1582 int32_t effective_input_to_output_scale_b,
1583 const int8_t* recurrent_to_input_weight_ptr,
1584 int32_t effective_recurrent_to_input_scale_a,
1585 int32_t effective_recurrent_to_input_scale_b,
1586 const int8_t* recurrent_to_forget_weight_ptr,
1587 int32_t effective_recurrent_to_forget_scale_a,
1588 int32_t effective_recurrent_to_forget_scale_b,
1589 const int8_t* recurrent_to_cell_weight_ptr,
1590 int32_t effective_recurrent_to_cell_scale_a,
1591 int32_t effective_recurrent_to_cell_scale_b,
1592 const int8_t* recurrent_to_output_weight_ptr,
1593 int32_t effective_recurrent_to_output_scale_a,
1594 int32_t effective_recurrent_to_output_scale_b,
1595 const int8_t* cell_to_input_weight_ptr,
1596 int32_t effective_cell_to_input_scale_a,
1597 int32_t effective_cell_to_input_scale_b,
1598 const int8_t* cell_to_forget_weight_ptr,
1599 int32_t effective_cell_to_forget_scale_a,
1600 int32_t effective_cell_to_forget_scale_b,
1601 const int8_t* cell_to_output_weight_ptr,
1602 int32_t effective_cell_to_output_scale_a,
1603 int32_t effective_cell_to_output_scale_b,
1604 const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
1605 int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
1606 int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
1607 const int16_t* layer_norm_forget_weight_ptr,
1608 int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
1609 const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
1610 int32_t layer_norm_cell_scale_b,
1611 const int16_t* layer_norm_output_weight_ptr,
1612 int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
1613 const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
1614 const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
1615 const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
1616 const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
1617 const int32_t* intermediate_zp, int16_t quantized_cell_clip,
1618 int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
1619 int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
1620 int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
1621 int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
1622 int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
1623 int16_t* scratch7) {
1624 // TODO(b/159066113): scratch5 is unused, remove.
1625
1626 ruy::profiler::ScopeLabel label("LstmStepInteger8x8_8");
1627 // Make named scratch buffers for the different gates.
1628 int16_t* forget_gate_scratch = scratch2;
1629 int16_t* cell_gate_scratch = scratch3;
1630 int16_t* output_gate_scratch = scratch4;
1631 // Only the CIFG variant is supported here, so there is no separate input
// gate scratch.
1632
1633 // Calculate the forget gate.
1634 CalculateLstmGateInteger8x8_8(
1635 input_ptr, input_zp, input_to_forget_weight_ptr,
1636 effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
1637 intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
1638 output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
1639 effective_recurrent_to_forget_scale_a,
1640 effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
1641 intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
1642 layer_norm_forget_scale_a, layer_norm_forget_scale_b,
1643 forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
1644 kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
1645 // Calculate the cell update gate.
1646 CalculateLstmGateInteger8x8_8(
1647 input_ptr, input_zp, input_to_cell_weight_ptr,
1648 effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
1649 intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
1650 output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
1651 effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
1652 intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
1653 layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
1654 layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
1655 n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
1656 // Update the cell state.
1657 UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
1658 /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
1659 forget_gate_scratch, cell_gate_scratch,
1660 /*use_cifg=*/true, quantized_cell_clip);
1661 // Calculate the output gate.
1662 CalculateLstmGateInteger8x8_8(
1663 input_ptr, input_zp, input_to_output_weight_ptr,
1664 effective_input_to_output_scale_a, effective_input_to_output_scale_b,
1665 intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
1666 output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
1667 effective_recurrent_to_output_scale_a,
1668 effective_recurrent_to_output_scale_b, intermediate_scale_a[11],
1669 intermediate_scale_b[7], intermediate_zp[7], layer_norm_output_weight_ptr,
1670 layer_norm_output_scale_a, layer_norm_output_scale_b,
1671 output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
1672 kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
1673 // Update the output state.
1674 CalculateLstmOutputInteger8x8_8(
1675 n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
1676 projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
1677 projection_bias_ptr, output_state_zp, quantized_proj_clip,
1678 output_state_ptr, scratch2);
1679 // Copy output state to the output. Note that unlike float or hybrid, output
1680 // is always contiguous.
1681 std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
1682 }
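
// Illustrative note: this variant fixes cell_state_scale to -15 (see the
// UpdateLstmCellInteger call above), so an int16 cell-state value q
// represents roughly q * 2^-15 and the representable cell-state range is
// about [-1, 1). quantized_cell_clip is presumably expressed in that same
// scale.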
1683
1684 } // namespace
1685
1686 // LINT.IfChange
1687 TfLiteStatus EvalFloat(
1688 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1689 const TfLiteTensor* input_to_forget_weights,
1690 const TfLiteTensor* input_to_cell_weights,
1691 const TfLiteTensor* input_to_output_weights,
1692 const TfLiteTensor* recurrent_to_input_weights,
1693 const TfLiteTensor* recurrent_to_forget_weights,
1694 const TfLiteTensor* recurrent_to_cell_weights,
1695 const TfLiteTensor* recurrent_to_output_weights,
1696 const TfLiteTensor* cell_to_input_weights,
1697 const TfLiteTensor* cell_to_forget_weights,
1698 const TfLiteTensor* cell_to_output_weights,
1699 const TfLiteTensor* input_layer_norm_coefficients,
1700 const TfLiteTensor* forget_layer_norm_coefficients,
1701 const TfLiteTensor* cell_layer_norm_coefficients,
1702 const TfLiteTensor* output_layer_norm_coefficients,
1703 const TfLiteTensor* aux_input,
1704 const TfLiteTensor* aux_input_to_input_weights,
1705 const TfLiteTensor* aux_input_to_forget_weights,
1706 const TfLiteTensor* aux_input_to_cell_weights,
1707 const TfLiteTensor* aux_input_to_output_weights,
1708 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1709 const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
1710 const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
1711 const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
1712 int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
1713 TfLiteTensor* cell_state, TfLiteTensor* output) {
1714 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1715 int max_time, n_batch;
1716 if (input->dims->size == 3) {
1717 max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1718 n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1719 } else {
1720 max_time = 1;
1721 n_batch = input->dims->data[0];
1722 }
1723 const int n_input = input->dims->data[input->dims->size - 1];
1724 const int aux_input_size =
1725 (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1726
1727 // n_cell and n_output will be the same size when there is no projection.
1728 const int n_cell = input_to_output_weights->dims->data[0];
1729 const int n_output = recurrent_to_output_weights->dims->data[1];
1730
1731 // Since we have already checked that weights are all there or none, we can
1732 // check the existence of only one to get the condition.
1733 const bool use_cifg = (input_to_input_weights == nullptr);
1734
1735 // Index the scratch buffers pointers to the global scratch buffer.
1736 float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
1737 float* input_gate_scratch = nullptr;
1738 float* cell_gate_scratch = nullptr;
1739 float* forget_gate_scratch = nullptr;
1740 float* output_gate_scratch = nullptr;
1741 if (use_cifg) {
1742 cell_gate_scratch = scratch_buffer_ptr;
1743 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1744 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1745 } else {
1746 input_gate_scratch = scratch_buffer_ptr;
1747 cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1748 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1749 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
1750 }
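  // For example (illustration only), with n_cell = 10 and n_batch = 2 the
  // non-CIFG layout above places the four gate scratch areas at float
  // offsets 0, 20, 40 and 60 of scratch_buffer, each n_cell * n_batch = 20
  // floats long; with CIFG the input-gate area is simply omitted.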
1751
1752 const int output_batch_leading_dim =
1753 output->dims->data[output->dims->size - 1];
1754 if (time_major) {
1755 // Loop through the sequence.
1756 const int input_step = n_batch * n_input;
1757 const int output_step = n_batch * output_batch_leading_dim;
1758 for (int t = 0; t < max_time; t++) {
1759 // If forward_sequence is true, step forward through time; otherwise
1760 // step backwards.
1761 const int t_rel = forward_sequence ? t : max_time - t - 1;
1762 const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
1763 const float* aux_input_ptr = nullptr;
1764 if (aux_input) {
1765 aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
1766 }
1767 float* output_ptr =
1768 GetTensorData<float>(output) + t_rel * output_step + output_offset;
1769
1770 LstmStepFloat(
1771 input_ptr, GetTensorData<float>(input_to_input_weights),
1772 GetTensorData<float>(input_to_forget_weights),
1773 GetTensorData<float>(input_to_cell_weights),
1774 GetTensorData<float>(input_to_output_weights), aux_input_ptr,
1775 GetTensorData<float>(aux_input_to_input_weights),
1776 GetTensorData<float>(aux_input_to_forget_weights),
1777 GetTensorData<float>(aux_input_to_cell_weights),
1778 GetTensorData<float>(aux_input_to_output_weights),
1779 GetTensorData<float>(recurrent_to_input_weights),
1780 GetTensorData<float>(recurrent_to_forget_weights),
1781 GetTensorData<float>(recurrent_to_cell_weights),
1782 GetTensorData<float>(recurrent_to_output_weights),
1783 GetTensorData<float>(cell_to_input_weights),
1784 GetTensorData<float>(cell_to_forget_weights),
1785 GetTensorData<float>(cell_to_output_weights),
1786 GetTensorData<float>(input_layer_norm_coefficients),
1787 GetTensorData<float>(forget_layer_norm_coefficients),
1788 GetTensorData<float>(cell_layer_norm_coefficients),
1789 GetTensorData<float>(output_layer_norm_coefficients),
1790 GetTensorData<float>(input_gate_bias),
1791 GetTensorData<float>(forget_gate_bias),
1792 GetTensorData<float>(cell_gate_bias),
1793 GetTensorData<float>(output_gate_bias),
1794 GetTensorData<float>(projection_weights),
1795 GetTensorData<float>(projection_bias), params, n_batch, n_cell,
1796 n_input, aux_input_size, n_output, output_batch_leading_dim,
1797 GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
1798 input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
1799 output_gate_scratch, output_ptr);
1800 }
1801 } else {
1802 for (int b = 0; b < n_batch; b++) {
1803 const int input_step = n_input;
1804 const int output_step = output_batch_leading_dim;
1805 for (int t = 0; t < max_time; t++) {
1806 // If forward_sequence is true, step forward through time; otherwise
1807 // step backwards.
1808 const int t_rel = forward_sequence ? t : max_time - t - 1;
1809 const int time_offset = b * max_time + t_rel;
1810 const float* input_ptr =
1811 GetTensorData<float>(input) + time_offset * input_step;
1812 const float* aux_input_ptr = nullptr;
1813 if (aux_input) {
1814 aux_input_ptr =
1815 GetTensorData<float>(aux_input) + time_offset * input_step;
1816 }
1817 float* output_ptr = GetTensorData<float>(output) +
1818 time_offset * output_step + output_offset;
1819
1820 // Offset the {output,cell}_state pointers to the right batch.
1821 float* output_state_ptr =
1822 GetTensorData<float>(output_state) + b * output_batch_leading_dim;
1823 float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
1824 // Offset the scratch pointers to the right batch.
1825 float* input_gate_scratch_ptr =
1826 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
1827 float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
1828 float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
1829 float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
1830
1831 LstmStepFloat(
1832 input_ptr, GetTensorData<float>(input_to_input_weights),
1833 GetTensorData<float>(input_to_forget_weights),
1834 GetTensorData<float>(input_to_cell_weights),
1835 GetTensorData<float>(input_to_output_weights), aux_input_ptr,
1836 GetTensorData<float>(aux_input_to_input_weights),
1837 GetTensorData<float>(aux_input_to_forget_weights),
1838 GetTensorData<float>(aux_input_to_cell_weights),
1839 GetTensorData<float>(aux_input_to_output_weights),
1840 GetTensorData<float>(recurrent_to_input_weights),
1841 GetTensorData<float>(recurrent_to_forget_weights),
1842 GetTensorData<float>(recurrent_to_cell_weights),
1843 GetTensorData<float>(recurrent_to_output_weights),
1844 GetTensorData<float>(cell_to_input_weights),
1845 GetTensorData<float>(cell_to_forget_weights),
1846 GetTensorData<float>(cell_to_output_weights),
1847 GetTensorData<float>(input_layer_norm_coefficients),
1848 GetTensorData<float>(forget_layer_norm_coefficients),
1849 GetTensorData<float>(cell_layer_norm_coefficients),
1850 GetTensorData<float>(output_layer_norm_coefficients),
1851 GetTensorData<float>(input_gate_bias),
1852 GetTensorData<float>(forget_gate_bias),
1853 GetTensorData<float>(cell_gate_bias),
1854 GetTensorData<float>(output_gate_bias),
1855 GetTensorData<float>(projection_weights),
1856 GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
1857 n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
1858 output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
1859 forget_gate_scratch_ptr, cell_gate_scratch_ptr,
1860 output_gate_scratch_ptr, output_ptr);
1861 }
1862 }
1863 }
1864 return kTfLiteOk;
1865 }
1866 // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
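
// Illustrative note on the indexing in EvalFloat above: for a time-major
// input of shape [max_time, n_batch, n_input], step t reads
// input + t * n_batch * n_input and writes
// output + t * n_batch * output_batch_leading_dim + output_offset; for a
// batch-major input of shape [n_batch, max_time, n_input], batch b at step t
// reads input + (b * max_time + t) * n_input, and the per-batch output-state,
// cell-state and gate-scratch pointers are offset by
// b * output_batch_leading_dim, b * n_cell and b * n_cell respectively.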
1867
1868 TfLiteStatus EvalHybrid(
1869 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1870 const TfLiteTensor* input_to_input_weights_ledger,
1871 const TfLiteTensor* input_to_forget_weights,
1872 const TfLiteTensor* input_to_forget_weights_ledger,
1873 const TfLiteTensor* input_to_cell_weights,
1874 const TfLiteTensor* input_to_cell_weights_ledger,
1875 const TfLiteTensor* input_to_output_weights,
1876 const TfLiteTensor* input_to_output_weights_ledger,
1877 const TfLiteTensor* recurrent_to_input_weights,
1878 const TfLiteTensor* recurrent_to_input_weights_ledger,
1879 const TfLiteTensor* recurrent_to_forget_weights,
1880 const TfLiteTensor* recurrent_to_forget_weights_ledger,
1881 const TfLiteTensor* recurrent_to_cell_weights,
1882 const TfLiteTensor* recurrent_to_cell_weights_ledger,
1883 const TfLiteTensor* recurrent_to_output_weights,
1884 const TfLiteTensor* recurrent_to_output_weights_ledger,
1885 const TfLiteTensor* cell_to_input_weights,
1886 const TfLiteTensor* cell_to_forget_weights,
1887 const TfLiteTensor* cell_to_output_weights,
1888 const TfLiteTensor* input_layer_norm_coefficients,
1889 const TfLiteTensor* forget_layer_norm_coefficients,
1890 const TfLiteTensor* cell_layer_norm_coefficients,
1891 const TfLiteTensor* output_layer_norm_coefficients,
1892 const TfLiteTensor* aux_input,
1893 const TfLiteTensor* aux_input_to_input_weights,
1894 const TfLiteTensor* aux_input_to_forget_weights,
1895 const TfLiteTensor* aux_input_to_cell_weights,
1896 const TfLiteTensor* aux_input_to_output_weights,
1897 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1898 const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
1899 const TfLiteTensor* projection_weights,
1900 const TfLiteTensor* projection_weights_ledger,
1901 const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
1902 bool forward_sequence, bool time_major, int output_offset,
1903 TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
1904 TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
1905 TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
1906 TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
1907 TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
1908 TfLiteTensor* output_state, TfLiteTensor* cell_state,
1909 TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
1910 TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
1911 TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
1912 bool* compute_row_sums, CpuBackendContext* context) {
1913 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1914 const int n_input = input->dims->data[input->dims->size - 1];
1915 int max_time, n_batch;
1916 if (input->dims->size == 2) {
1917 max_time = 1;
1918 n_batch = input->dims->data[0];
1919 } else {
1920 max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1921 n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1922 }
1923 const int aux_input_size =
1924 (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1925 // n_cell and n_output will be the same size when there is no projection.
1926 const int n_cell = input_to_output_weights->dims->data[0];
1927 const int n_output = recurrent_to_output_weights->dims->data[1];
1928
1929 // Since we have already checked that weights are all there or none, we can
1930 // check the existence of only one to get the condition.
1931 const bool use_cifg = (input_to_input_weights == nullptr);
1932
1933 float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
1934 float* input_gate_scratch = nullptr;
1935 float* cell_gate_scratch = nullptr;
1936 float* forget_gate_scratch = nullptr;
1937 float* output_gate_scratch = nullptr;
1938 if (use_cifg) {
1939 cell_gate_scratch = scratch_buffer_ptr;
1940 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1941 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1942 } else {
1943 input_gate_scratch = scratch_buffer_ptr;
1944 cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1945 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1946 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
1947 }
1948
1949 const int output_batch_leading_dim =
1950 output->dims->data[output->dims->size - 1];
1951
1952 int32_t* input_zp_ptr = nullptr;
1953 int32_t* aux_input_zp_ptr = nullptr;
1954 int32_t* output_state_zp_ptr = nullptr;
1955 int32_t* row_sums_ptr = nullptr;
1956 if (params->asymmetric_quantize_inputs) {
1957 input_zp_ptr = GetTensorData<int32_t>(input_zp);
1958 aux_input_zp_ptr = GetTensorData<int32_t>(aux_input_zp);
1959 output_state_zp_ptr = GetTensorData<int32_t>(output_state_zp);
1960 row_sums_ptr = GetTensorData<int32_t>(row_sums);
1961 }
1962
1963 if (time_major) {
1964 // Feed the sequence into the LSTM step-by-step.
1965 const int input_step = n_batch * n_input;
1966 const int output_step = n_batch * output_batch_leading_dim;
1967 for (int t = 0; t < max_time; t++) {
1968       // When forward_sequence is true, step forward through time;
1969       // otherwise step backwards.
1970 const int t_rel = forward_sequence ? t : max_time - t - 1;
1971 const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
1972 const float* aux_input_ptr = nullptr;
1973 if (aux_input) {
1974 aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
1975 }
1976 float* output_ptr =
1977 GetTensorData<float>(output) + t_rel * output_step + output_offset;
1978 LstmStepHybrid(
1979 input_ptr, GetTensorData<int8_t>(input_to_input_weights),
1980 GetTensorData<uint8_t>(input_to_input_weights_ledger),
1981 GetTensorScale(input_to_input_weights),
1982 GetTensorData<int8_t>(input_to_forget_weights),
1983 GetTensorData<uint8_t>(input_to_forget_weights_ledger),
1984 GetTensorScale(input_to_forget_weights),
1985 GetTensorData<int8_t>(input_to_cell_weights),
1986 GetTensorData<uint8_t>(input_to_cell_weights_ledger),
1987 GetTensorScale(input_to_cell_weights),
1988 GetTensorData<int8_t>(input_to_output_weights),
1989 GetTensorData<uint8_t>(input_to_output_weights_ledger),
1990 GetTensorScale(input_to_output_weights), aux_input_ptr,
1991 GetTensorData<int8_t>(aux_input_to_input_weights),
1992 GetTensorScale(aux_input_to_input_weights),
1993 GetTensorData<int8_t>(aux_input_to_forget_weights),
1994 GetTensorScale(aux_input_to_forget_weights),
1995 GetTensorData<int8_t>(aux_input_to_cell_weights),
1996 GetTensorScale(aux_input_to_cell_weights),
1997 GetTensorData<int8_t>(aux_input_to_output_weights),
1998 GetTensorScale(aux_input_to_output_weights),
1999 GetTensorData<int8_t>(recurrent_to_input_weights),
2000 GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
2001 GetTensorScale(recurrent_to_input_weights),
2002 GetTensorData<int8_t>(recurrent_to_forget_weights),
2003 GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
2004 GetTensorScale(recurrent_to_forget_weights),
2005 GetTensorData<int8_t>(recurrent_to_cell_weights),
2006 GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
2007 GetTensorScale(recurrent_to_cell_weights),
2008 GetTensorData<int8_t>(recurrent_to_output_weights),
2009 GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
2010 GetTensorScale(recurrent_to_output_weights),
2011 GetTensorData<int8_t>(cell_to_input_weights),
2012 GetTensorScale(cell_to_input_weights),
2013 GetTensorData<int8_t>(cell_to_forget_weights),
2014 GetTensorScale(cell_to_forget_weights),
2015 GetTensorData<int8_t>(cell_to_output_weights),
2016 GetTensorScale(cell_to_output_weights),
2017 GetTensorData<float>(input_layer_norm_coefficients),
2018 GetTensorData<float>(forget_layer_norm_coefficients),
2019 GetTensorData<float>(cell_layer_norm_coefficients),
2020 GetTensorData<float>(output_layer_norm_coefficients),
2021 GetTensorData<float>(input_gate_bias),
2022 GetTensorData<float>(forget_gate_bias),
2023 GetTensorData<float>(cell_gate_bias),
2024 GetTensorData<float>(output_gate_bias),
2025 GetTensorData<int8_t>(projection_weights),
2026 GetTensorData<uint8_t>(projection_weights_ledger),
2027 GetTensorScale(projection_weights),
2028 GetTensorData<float>(projection_bias), params, n_batch, n_cell,
2029 n_input, aux_input_size, n_output, output_batch_leading_dim,
2030 input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
2031 output_gate_scratch, GetTensorData<float>(input_sf),
2032 GetTensorData<float>(aux_input_sf),
2033 GetTensorData<float>(output_state_sf),
2034 GetTensorData<float>(prod_scaling_factors),
2035 GetTensorData<float>(recovered_cell_weights),
2036 GetTensorData<int8_t>(input_quantized),
2037 GetTensorData<int8_t>(aux_input_quantized),
2038 GetTensorData<int8_t>(output_state_quantized),
2039 GetTensorData<int8_t>(cell_state_quantized),
2040 GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
2041 GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
2042 input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr,
2043 row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs,
2044 context);
2045 }
2046 } else {
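    // Batch-major input: process each batch entry independently, one timestep
    // at a time, calling LstmStepHybrid with n_batch = 1.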
2047 for (int b = 0; b < n_batch; b++) {
2048 const int input_step = n_input;
2049 const int output_step = output_batch_leading_dim;
2050 for (int t = 0; t < max_time; t++) {
2051         // When forward_sequence is true, step forward through time;
2052         // otherwise step backwards.
2053 const int t_rel = forward_sequence ? t : max_time - t - 1;
2054 const int time_offset = b * max_time + t_rel;
2055 const float* input_ptr =
2056 GetTensorData<float>(input) + time_offset * input_step;
2057 const float* aux_input_ptr = nullptr;
2058 if (aux_input) {
2059 aux_input_ptr =
2060 GetTensorData<float>(aux_input) + time_offset * input_step;
2061 }
2062 float* output_ptr = GetTensorData<float>(output) +
2063 time_offset * output_step + output_offset;
2064
2065 // Offset the {output,cell}_state pointers to the right batch.
2066 float* output_state_ptr =
2067 GetTensorData<float>(output_state) + b * output_batch_leading_dim;
2068 float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
2069 // Offset the scratch pointers to the right batch.
2070 float* input_gate_scratch_ptr =
2071 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
2072 float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
2073 float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
2074 float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
2075
2076 LstmStepHybrid(
2077 input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2078 GetTensorData<uint8_t>(input_to_input_weights_ledger),
2079 GetTensorScale(input_to_input_weights),
2080 GetTensorData<int8_t>(input_to_forget_weights),
2081 GetTensorData<uint8_t>(input_to_forget_weights_ledger),
2082 GetTensorScale(input_to_forget_weights),
2083 GetTensorData<int8_t>(input_to_cell_weights),
2084 GetTensorData<uint8_t>(input_to_cell_weights_ledger),
2085 GetTensorScale(input_to_cell_weights),
2086 GetTensorData<int8_t>(input_to_output_weights),
2087 GetTensorData<uint8_t>(input_to_output_weights_ledger),
2088 GetTensorScale(input_to_output_weights), aux_input_ptr,
2089 GetTensorData<int8_t>(aux_input_to_input_weights),
2090 GetTensorScale(aux_input_to_input_weights),
2091 GetTensorData<int8_t>(aux_input_to_forget_weights),
2092 GetTensorScale(aux_input_to_forget_weights),
2093 GetTensorData<int8_t>(aux_input_to_cell_weights),
2094 GetTensorScale(aux_input_to_cell_weights),
2095 GetTensorData<int8_t>(aux_input_to_output_weights),
2096 GetTensorScale(aux_input_to_output_weights),
2097 GetTensorData<int8_t>(recurrent_to_input_weights),
2098 GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
2099 GetTensorScale(recurrent_to_input_weights),
2100 GetTensorData<int8_t>(recurrent_to_forget_weights),
2101 GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
2102 GetTensorScale(recurrent_to_forget_weights),
2103 GetTensorData<int8_t>(recurrent_to_cell_weights),
2104 GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
2105 GetTensorScale(recurrent_to_cell_weights),
2106 GetTensorData<int8_t>(recurrent_to_output_weights),
2107 GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
2108 GetTensorScale(recurrent_to_output_weights),
2109 GetTensorData<int8_t>(cell_to_input_weights),
2110 GetTensorScale(cell_to_input_weights),
2111 GetTensorData<int8_t>(cell_to_forget_weights),
2112 GetTensorScale(cell_to_forget_weights),
2113 GetTensorData<int8_t>(cell_to_output_weights),
2114 GetTensorScale(cell_to_output_weights),
2115 GetTensorData<float>(input_layer_norm_coefficients),
2116 GetTensorData<float>(forget_layer_norm_coefficients),
2117 GetTensorData<float>(cell_layer_norm_coefficients),
2118 GetTensorData<float>(output_layer_norm_coefficients),
2119 GetTensorData<float>(input_gate_bias),
2120 GetTensorData<float>(forget_gate_bias),
2121 GetTensorData<float>(cell_gate_bias),
2122 GetTensorData<float>(output_gate_bias),
2123 GetTensorData<int8_t>(projection_weights),
2124 GetTensorData<uint8_t>(projection_weights_ledger),
2125 GetTensorScale(projection_weights),
2126 GetTensorData<float>(projection_bias), params,
2127 /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
2128 output_batch_leading_dim, input_gate_scratch_ptr,
2129 forget_gate_scratch_ptr, cell_gate_scratch_ptr,
2130 output_gate_scratch_ptr, GetTensorData<float>(input_sf),
2131 GetTensorData<float>(aux_input_sf),
2132 GetTensorData<float>(output_state_sf),
2133 GetTensorData<float>(prod_scaling_factors),
2134 GetTensorData<float>(recovered_cell_weights),
2135 GetTensorData<int8_t>(input_quantized),
2136 GetTensorData<int8_t>(aux_input_quantized),
2137 GetTensorData<int8_t>(output_state_quantized),
2138 GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
2139 cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
2140 output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr,
2141 row_sums_ptr, row_sums_size, compute_row_sums,
2142 params->asymmetric_quantize_inputs, context);
2143 }
2144 }
2145 }
2146
2147 return kTfLiteOk;
2148 }
2149
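// Fully-integer LSTM evaluation for the 8x8_16 kernel (int8 activations,
// int16 cell state), driven by the precomputed scales and effective biases in
// IntegerLstmParameter.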
2150 TfLiteStatus EvalInteger8x8_16(
2151 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
2152 const TfLiteTensor* input_to_forget_weights,
2153 const TfLiteTensor* input_to_cell_weights,
2154 const TfLiteTensor* input_to_output_weights,
2155 const TfLiteTensor* recurrent_to_input_weights,
2156 const TfLiteTensor* recurrent_to_forget_weights,
2157 const TfLiteTensor* recurrent_to_cell_weights,
2158 const TfLiteTensor* recurrent_to_output_weights,
2159 const TfLiteTensor* cell_to_input_weights,
2160 const TfLiteTensor* cell_to_forget_weights,
2161 const TfLiteTensor* cell_to_output_weights,
2162 const TfLiteTensor* input_layer_norm_coefficients,
2163 const TfLiteTensor* forget_layer_norm_coefficients,
2164 const TfLiteTensor* cell_layer_norm_coefficients,
2165 const TfLiteTensor* output_layer_norm_coefficients,
2166 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
2167 const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
2168 const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
2169 const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
2170 const lstm_eval::IntegerLstmParameter* integer_lstm_param,
2171 TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
2172 TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
2173 TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
2174 CpuBackendContext* context) {
2175 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
2176 const int n_input = input->dims->data[input->dims->size - 1];
2177 int max_time, n_batch;
2178 if (input->dims->size == 2) {
2179 max_time = 1;
2180 n_batch = input->dims->data[0];
2181 } else {
2182 max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
2183 n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
2184 }
2185
2186 // n_cell and n_output will be the same size when there is no projection.
2187 const int n_cell = input_to_output_weights->dims->data[0];
2188 const int n_output = recurrent_to_output_weights->dims->data[1];
2189
2190   // Zero point of the output state (activation).
2191 int output_state_zp = output_state->params.zero_point;
2192
2193 // Get params for time/batch/sequence.
2194 const int output_batch_leading_dim =
2195 output->dims->data[output->dims->size - 1];
2196
2197 if (time_major) {
2198 const int input_step = n_batch * n_input;
2199 const int output_step = n_batch * output_batch_leading_dim;
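    // This time-major path always steps forward in time (t_rel == t); reverse
    // sequences are only handled by the batch-major branch below.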
2200 for (int t = 0; t < max_time; t++) {
2201 const int t_rel = t;
2202 int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
2203 const int8_t* input_ptr =
2204 GetTensorData<int8_t>(input) + t_rel * input_step;
2205 LstmStepInteger8x8_16(
2206 input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2207 integer_lstm_param->effective_input_to_input_scale_a,
2208 integer_lstm_param->effective_input_to_input_scale_b,
2209 GetTensorData<int8_t>(input_to_forget_weights),
2210 integer_lstm_param->effective_input_to_forget_scale_a,
2211 integer_lstm_param->effective_input_to_forget_scale_b,
2212 GetTensorData<int8_t>(input_to_cell_weights),
2213 integer_lstm_param->effective_input_to_cell_scale_a,
2214 integer_lstm_param->effective_input_to_cell_scale_b,
2215 GetTensorData<int8_t>(input_to_output_weights),
2216 integer_lstm_param->effective_input_to_output_scale_a,
2217 integer_lstm_param->effective_input_to_output_scale_b,
2218 GetTensorData<int8_t>(recurrent_to_input_weights),
2219 integer_lstm_param->effective_recurrent_to_input_scale_a,
2220 integer_lstm_param->effective_recurrent_to_input_scale_b,
2221 GetTensorData<int8_t>(recurrent_to_forget_weights),
2222 integer_lstm_param->effective_recurrent_to_forget_scale_a,
2223 integer_lstm_param->effective_recurrent_to_forget_scale_b,
2224 GetTensorData<int8_t>(recurrent_to_cell_weights),
2225 integer_lstm_param->effective_recurrent_to_cell_scale_a,
2226 integer_lstm_param->effective_recurrent_to_cell_scale_b,
2227 GetTensorData<int8_t>(recurrent_to_output_weights),
2228 integer_lstm_param->effective_recurrent_to_output_scale_a,
2229 integer_lstm_param->effective_recurrent_to_output_scale_b,
2230 GetTensorData<int16_t>(cell_to_input_weights),
2231 integer_lstm_param->effective_cell_to_input_scale_a,
2232 integer_lstm_param->effective_cell_to_input_scale_b,
2233 GetTensorData<int16_t>(cell_to_forget_weights),
2234 integer_lstm_param->effective_cell_to_forget_scale_a,
2235 integer_lstm_param->effective_cell_to_forget_scale_b,
2236 GetTensorData<int16_t>(cell_to_output_weights),
2237 integer_lstm_param->effective_cell_to_output_scale_a,
2238 integer_lstm_param->effective_cell_to_output_scale_b,
2239 GetTensorData<int8_t>(projection_weights),
2240 integer_lstm_param->effective_proj_scale_a,
2241 integer_lstm_param->effective_proj_scale_b,
2242 integer_lstm_param->hidden_zp,
2243 integer_lstm_param->effective_hidden_scale_a,
2244 integer_lstm_param->effective_hidden_scale_b,
2245 GetTensorData<int16_t>(input_layer_norm_coefficients),
2246 integer_lstm_param->layer_norm_input_scale_a,
2247 integer_lstm_param->layer_norm_input_scale_b,
2248 GetTensorData<int16_t>(forget_layer_norm_coefficients),
2249 integer_lstm_param->layer_norm_forget_scale_a,
2250 integer_lstm_param->layer_norm_forget_scale_b,
2251 GetTensorData<int16_t>(cell_layer_norm_coefficients),
2252 integer_lstm_param->layer_norm_cell_scale_a,
2253 integer_lstm_param->layer_norm_cell_scale_b,
2254 GetTensorData<int16_t>(output_layer_norm_coefficients),
2255 integer_lstm_param->layer_norm_output_scale_a,
2256 integer_lstm_param->layer_norm_output_scale_b,
2257 GetTensorData<int32_t>(input_gate_bias),
2258 GetTensorData<int32_t>(forget_gate_bias),
2259 GetTensorData<int32_t>(cell_gate_bias),
2260 GetTensorData<int32_t>(output_gate_bias),
2261 integer_lstm_param->quantized_cell_clip,
2262 integer_lstm_param->quantized_proj_clip,
2263 integer_lstm_param->cell_scale,
2264 integer_lstm_param->input_variance_guard,
2265 integer_lstm_param->forget_variance_guard,
2266 integer_lstm_param->cell_variance_guard,
2267 integer_lstm_param->output_variance_guard,
2268 integer_lstm_param->input_to_forget_effective_bias.get(),
2269 integer_lstm_param->recurrent_to_forget_effective_bias.get(),
2270 integer_lstm_param->input_to_cell_effective_bias.get(),
2271 integer_lstm_param->recurrent_to_cell_effective_bias.get(),
2272 integer_lstm_param->input_to_output_effective_bias.get(),
2273 integer_lstm_param->recurrent_to_output_effective_bias.get(),
2274 integer_lstm_param->input_to_input_effective_bias.get(),
2275 integer_lstm_param->recurrent_to_input_effective_bias.get(),
2276 integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
2277 n_input, n_output, GetTensorData<int8_t>(output_state),
2278 output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
2279 GetTensorData<int16_t>(scratch0), GetTensorData<int16_t>(scratch1),
2280 GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
2281 GetTensorData<int8_t>(scratch4), GetTensorData<int32_t>(scratch5),
2282 context);
2283 }
2284 } else {
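    // Batch-major input: iterate batches in the outer loop and timesteps in the
    // inner loop, stepping LstmStepInteger8x8_16 with n_batch = 1.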
2285 for (int b = 0; b < n_batch; b++) {
2286 const int input_step = n_input;
2287 const int output_step = output_batch_leading_dim;
2288 for (int t = 0; t < max_time; t++) {
2289         // When forward_sequence is true, step forward through time;
2290         // otherwise step backwards.
2291 const int t_rel = forward_sequence ? t : max_time - t - 1;
2292 const int time_offset = b * max_time + t_rel;
2293 const int8_t* input_ptr =
2294 GetTensorData<int8_t>(input) + time_offset * input_step;
2295 int8_t* output_ptr =
2296 GetTensorData<int8_t>(output) + time_offset * output_step;
2297
2298 // Offset the {output,cell}_state pointers to the right batch.
2299 int8_t* output_state_ptr =
2300 GetTensorData<int8_t>(output_state) + b * output_batch_leading_dim;
2301 int16_t* cell_state_ptr =
2302 GetTensorData<int16_t>(cell_state) + b * n_cell;
2303
2304 LstmStepInteger8x8_16(
2305 input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2306 integer_lstm_param->effective_input_to_input_scale_a,
2307 integer_lstm_param->effective_input_to_input_scale_b,
2308 GetTensorData<int8_t>(input_to_forget_weights),
2309 integer_lstm_param->effective_input_to_forget_scale_a,
2310 integer_lstm_param->effective_input_to_forget_scale_b,
2311 GetTensorData<int8_t>(input_to_cell_weights),
2312 integer_lstm_param->effective_input_to_cell_scale_a,
2313 integer_lstm_param->effective_input_to_cell_scale_b,
2314 GetTensorData<int8_t>(input_to_output_weights),
2315 integer_lstm_param->effective_input_to_output_scale_a,
2316 integer_lstm_param->effective_input_to_output_scale_b,
2317 GetTensorData<int8_t>(recurrent_to_input_weights),
2318 integer_lstm_param->effective_recurrent_to_input_scale_a,
2319 integer_lstm_param->effective_recurrent_to_input_scale_b,
2320 GetTensorData<int8_t>(recurrent_to_forget_weights),
2321 integer_lstm_param->effective_recurrent_to_forget_scale_a,
2322 integer_lstm_param->effective_recurrent_to_forget_scale_b,
2323 GetTensorData<int8_t>(recurrent_to_cell_weights),
2324 integer_lstm_param->effective_recurrent_to_cell_scale_a,
2325 integer_lstm_param->effective_recurrent_to_cell_scale_b,
2326 GetTensorData<int8_t>(recurrent_to_output_weights),
2327 integer_lstm_param->effective_recurrent_to_output_scale_a,
2328 integer_lstm_param->effective_recurrent_to_output_scale_b,
2329 GetTensorData<int16_t>(cell_to_input_weights),
2330 integer_lstm_param->effective_cell_to_input_scale_a,
2331 integer_lstm_param->effective_cell_to_input_scale_b,
2332 GetTensorData<int16_t>(cell_to_forget_weights),
2333 integer_lstm_param->effective_cell_to_forget_scale_a,
2334 integer_lstm_param->effective_cell_to_forget_scale_b,
2335 GetTensorData<int16_t>(cell_to_output_weights),
2336 integer_lstm_param->effective_cell_to_output_scale_a,
2337 integer_lstm_param->effective_cell_to_output_scale_b,
2338 GetTensorData<int8_t>(projection_weights),
2339 integer_lstm_param->effective_proj_scale_a,
2340 integer_lstm_param->effective_proj_scale_b,
2341 integer_lstm_param->hidden_zp,
2342 integer_lstm_param->effective_hidden_scale_a,
2343 integer_lstm_param->effective_hidden_scale_b,
2344 GetTensorData<int16_t>(input_layer_norm_coefficients),
2345 integer_lstm_param->layer_norm_input_scale_a,
2346 integer_lstm_param->layer_norm_input_scale_b,
2347 GetTensorData<int16_t>(forget_layer_norm_coefficients),
2348 integer_lstm_param->layer_norm_forget_scale_a,
2349 integer_lstm_param->layer_norm_forget_scale_b,
2350 GetTensorData<int16_t>(cell_layer_norm_coefficients),
2351 integer_lstm_param->layer_norm_cell_scale_a,
2352 integer_lstm_param->layer_norm_cell_scale_b,
2353 GetTensorData<int16_t>(output_layer_norm_coefficients),
2354 integer_lstm_param->layer_norm_output_scale_a,
2355 integer_lstm_param->layer_norm_output_scale_b,
2356 GetTensorData<int32_t>(input_gate_bias),
2357 GetTensorData<int32_t>(forget_gate_bias),
2358 GetTensorData<int32_t>(cell_gate_bias),
2359 GetTensorData<int32_t>(output_gate_bias),
2360 integer_lstm_param->quantized_cell_clip,
2361 integer_lstm_param->quantized_proj_clip,
2362 integer_lstm_param->cell_scale,
2363 integer_lstm_param->input_variance_guard,
2364 integer_lstm_param->forget_variance_guard,
2365 integer_lstm_param->cell_variance_guard,
2366 integer_lstm_param->output_variance_guard,
2367 integer_lstm_param->input_to_forget_effective_bias.get(),
2368 integer_lstm_param->recurrent_to_forget_effective_bias.get(),
2369 integer_lstm_param->input_to_cell_effective_bias.get(),
2370 integer_lstm_param->recurrent_to_cell_effective_bias.get(),
2371 integer_lstm_param->input_to_output_effective_bias.get(),
2372 integer_lstm_param->recurrent_to_output_effective_bias.get(),
2373 integer_lstm_param->input_to_input_effective_bias.get(),
2374 integer_lstm_param->recurrent_to_input_effective_bias.get(),
2375 integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1,
2376 n_cell, n_input, n_output, output_state_ptr, output_state_zp,
2377 cell_state_ptr, output_ptr, GetTensorData<int16_t>(scratch0),
2378 GetTensorData<int16_t>(scratch1), GetTensorData<int16_t>(scratch2),
2379 GetTensorData<int16_t>(scratch3), GetTensorData<int8_t>(scratch4),
2380 GetTensorData<int32_t>(scratch5), context);
2381 }
2382 }
2383 }
2384
2385 return kTfLiteOk;
2386 }
2387
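// Fully-integer LSTM evaluation for the 8x8_8 kernel. This variant assumes a
// time-major, forward-running sequence (there are no time_major /
// forward_sequence parameters) and forwards the precomputed intermediate
// scales and zero points to LstmStepInteger8x8_8.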
2388 TfLiteStatus EvalInteger8x8_8(
2389 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
2390 const TfLiteTensor* input_to_forget_weights,
2391 const TfLiteTensor* input_to_cell_weights,
2392 const TfLiteTensor* input_to_output_weights,
2393 const TfLiteTensor* recurrent_to_input_weights,
2394 const TfLiteTensor* recurrent_to_forget_weights,
2395 const TfLiteTensor* recurrent_to_cell_weights,
2396 const TfLiteTensor* recurrent_to_output_weights,
2397 const TfLiteTensor* cell_to_input_weights,
2398 const TfLiteTensor* cell_to_forget_weights,
2399 const TfLiteTensor* cell_to_output_weights,
2400 const TfLiteTensor* input_layer_norm_coefficients,
2401 const TfLiteTensor* forget_layer_norm_coefficients,
2402 const TfLiteTensor* cell_layer_norm_coefficients,
2403 const TfLiteTensor* output_layer_norm_coefficients,
2404 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
2405 const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
2406 const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
2407 const TfLiteLSTMParams* params, TfLiteTensor* output_state,
2408 TfLiteTensor* cell_state, TfLiteTensor* output,
2409 const lstm_eval::IntegerLstmParameter* integer_lstm_param,
2410 TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
2411 TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
2412 TfLiteTensor* scratch6, TfLiteTensor* scratch7) {
2413 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
2414 const int n_input = input->dims->data[input->dims->size - 1];
2415 int max_time, n_batch;
2416 if (input->dims->size == 2) {
2417 max_time = 1;
2418 n_batch = input->dims->data[0];
2419 } else {
2420 max_time = input->dims->data[0];
2421 n_batch = input->dims->data[1];
2422 }
2423
2424 // n_cell and n_output will be the same size when there is no projection.
2425 const int n_cell = input_to_output_weights->dims->data[0];
2426 const int n_output = recurrent_to_output_weights->dims->data[1];
2427
2428 const int32_t input_zp = input->params.zero_point;
2429 const int32_t output_state_zp = output_state->params.zero_point;
2430
2431 // Get params for time/batch/sequence.
2432 const int output_batch_leading_dim =
2433 output->dims->data[output->dims->size - 1];
2434 const int input_step = n_batch * n_input;
2435 const int output_step = n_batch * output_batch_leading_dim;
2436
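  // Each timestep consumes a [n_batch, n_input] slice of the input and writes a
  // [n_batch, output_batch_leading_dim] slice of the output.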
2437 for (int t = 0; t < max_time; t++) {
2438 const int t_rel = t;
2439 int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
2440 // Input can be int8 asymmetric or int16 symmetric.
2441 const int8_t* input_ptr = GetTensorData<int8_t>(input) + t_rel * input_step;
2442 lstm_eval::LstmStepInteger8x8_8(
2443 input_ptr, input_zp,
2444
2445 GetTensorData<int8_t>(input_to_input_weights),
2446 integer_lstm_param->effective_input_to_input_scale_a,
2447 integer_lstm_param->effective_input_to_input_scale_b,
2448
2449 GetTensorData<int8_t>(input_to_forget_weights),
2450 integer_lstm_param->effective_input_to_forget_scale_a,
2451 integer_lstm_param->effective_input_to_forget_scale_b,
2452
2453 GetTensorData<int8_t>(input_to_cell_weights),
2454 integer_lstm_param->effective_input_to_cell_scale_a,
2455 integer_lstm_param->effective_input_to_cell_scale_b,
2456
2457 GetTensorData<int8_t>(input_to_output_weights),
2458 integer_lstm_param->effective_input_to_output_scale_a,
2459 integer_lstm_param->effective_input_to_output_scale_b,
2460
2461 GetTensorData<int8_t>(recurrent_to_input_weights),
2462 integer_lstm_param->effective_recurrent_to_input_scale_a,
2463 integer_lstm_param->effective_recurrent_to_input_scale_b,
2464
2465 GetTensorData<int8_t>(recurrent_to_forget_weights),
2466 integer_lstm_param->effective_recurrent_to_forget_scale_a,
2467 integer_lstm_param->effective_recurrent_to_forget_scale_b,
2468
2469 GetTensorData<int8_t>(recurrent_to_cell_weights),
2470 integer_lstm_param->effective_recurrent_to_cell_scale_a,
2471 integer_lstm_param->effective_recurrent_to_cell_scale_b,
2472
2473 GetTensorData<int8_t>(recurrent_to_output_weights),
2474 integer_lstm_param->effective_recurrent_to_output_scale_a,
2475 integer_lstm_param->effective_recurrent_to_output_scale_b,
2476
2477 GetTensorData<int8_t>(cell_to_input_weights),
2478 integer_lstm_param->effective_cell_to_input_scale_a,
2479 integer_lstm_param->effective_cell_to_input_scale_b,
2480
2481 GetTensorData<int8_t>(cell_to_forget_weights),
2482 integer_lstm_param->effective_cell_to_forget_scale_a,
2483 integer_lstm_param->effective_cell_to_forget_scale_b,
2484
2485 GetTensorData<int8_t>(cell_to_output_weights),
2486 integer_lstm_param->effective_cell_to_output_scale_a,
2487 integer_lstm_param->effective_cell_to_output_scale_b,
2488
2489 GetTensorData<int8_t>(projection_weights),
2490 integer_lstm_param->effective_proj_scale_a,
2491 integer_lstm_param->effective_proj_scale_b,
2492
2493 GetTensorData<int16_t>(input_layer_norm_coefficients),
2494 integer_lstm_param->layer_norm_input_scale_a,
2495 integer_lstm_param->layer_norm_input_scale_b,
2496
2497 GetTensorData<int16_t>(forget_layer_norm_coefficients),
2498 integer_lstm_param->layer_norm_forget_scale_a,
2499 integer_lstm_param->layer_norm_forget_scale_b,
2500
2501 GetTensorData<int16_t>(cell_layer_norm_coefficients),
2502 integer_lstm_param->layer_norm_cell_scale_a,
2503 integer_lstm_param->layer_norm_cell_scale_b,
2504
2505 GetTensorData<int16_t>(output_layer_norm_coefficients),
2506 integer_lstm_param->layer_norm_output_scale_a,
2507 integer_lstm_param->layer_norm_output_scale_b,
2508
2509 GetTensorData<int32_t>(input_gate_bias),
2510 GetTensorData<int32_t>(forget_gate_bias),
2511 GetTensorData<int32_t>(cell_gate_bias),
2512 GetTensorData<int32_t>(output_gate_bias),
2513 GetTensorData<int32_t>(projection_bias),
2514
2515 params, integer_lstm_param->intermediate_scale_a,
2516 integer_lstm_param->intermediate_scale_b,
2517 integer_lstm_param->intermediate_zp,
2518 integer_lstm_param->quantized_cell_clip,
2519 integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
2520 n_output, output_batch_leading_dim, GetTensorData<int8_t>(output_state),
2521 output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
2522 GetTensorData<int8_t>(scratch0), GetTensorData<int8_t>(scratch1),
2523 GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
2524 GetTensorData<int16_t>(scratch4), GetTensorData<int16_t>(scratch5),
2525 GetTensorData<int16_t>(scratch6), GetTensorData<int16_t>(scratch7));
2526 }
2527
2528 return kTfLiteOk;
2529 }
2530
2531 } // namespace lstm_eval
2532 } // namespace builtin
2533 } // namespace ops
2534 } // namespace tflite
2535