/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
#define TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_

#include <cstdint>
#include <memory>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {

// Parameters for integer LSTM.
// Consider splitting this into two integer parameter structs if more fields
// are added.
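// Each *_scale_a / *_scale_b pair below holds a fixed-point multiplier and a
// bit shift, respectively (the multiplier/shift pair produced by
// QuantizeMultiplier).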
struct IntegerLstmParameter {
  int32_t effective_input_to_input_scale_a;
  int32_t effective_input_to_input_scale_b;
  int32_t effective_recurrent_to_input_scale_a;
  int32_t effective_recurrent_to_input_scale_b;
  int32_t effective_cell_to_input_scale_a;
  int32_t effective_cell_to_input_scale_b;
  int32_t effective_input_to_forget_scale_a;
  int32_t effective_input_to_forget_scale_b;
  int32_t effective_recurrent_to_forget_scale_a;
  int32_t effective_recurrent_to_forget_scale_b;
  int32_t effective_cell_to_forget_scale_a;
  int32_t effective_cell_to_forget_scale_b;
  int32_t effective_input_to_cell_scale_a;
  int32_t effective_input_to_cell_scale_b;
  int32_t effective_recurrent_to_cell_scale_a;
  int32_t effective_recurrent_to_cell_scale_b;
  int32_t effective_input_to_output_scale_a;
  int32_t effective_input_to_output_scale_b;
  int32_t effective_recurrent_to_output_scale_a;
  int32_t effective_recurrent_to_output_scale_b;
  int32_t effective_cell_to_output_scale_a;
  int32_t effective_cell_to_output_scale_b;
  int32_t effective_proj_scale_a;
  int32_t effective_proj_scale_b;
  int32_t effective_hidden_scale_a;
  int32_t effective_hidden_scale_b;
  int32_t layer_norm_input_scale_a;
  int32_t layer_norm_input_scale_b;
  int32_t layer_norm_forget_scale_a;
  int32_t layer_norm_forget_scale_b;
  int32_t layer_norm_cell_scale_a;
  int32_t layer_norm_cell_scale_b;
  int32_t layer_norm_output_scale_a;
  int32_t layer_norm_output_scale_b;
  // Quantized clip value for cell and projection. Zero value means no
  // clipping.
  int16_t quantized_cell_clip;
  int8_t quantized_proj_clip;
  int32_t hidden_zp;
  int32_t cell_scale;

  int32_t input_variance_guard;
  int32_t forget_variance_guard;
  int32_t cell_variance_guard;
  int32_t output_variance_guard;
  // Pre-calculated bias + zero_point * weight.
  // Unable to use temporary tensors since those are used in Prepare() and
  // the scratch buffer is only allocated after Prepare().
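  // (Conceptually, effective_bias[row] = bias[row] + zero_point * sum of the
  // corresponding weight row, folding the zero-point correction into the
  // bias ahead of time.)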
  std::unique_ptr<int32_t[]> input_to_forget_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_forget_effective_bias;
  std::unique_ptr<int32_t[]> input_to_cell_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_cell_effective_bias;
  std::unique_ptr<int32_t[]> input_to_output_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_output_effective_bias;
  std::unique_ptr<int32_t[]> input_to_input_effective_bias;
  std::unique_ptr<int32_t[]> recurrent_to_input_effective_bias;
  std::unique_ptr<int32_t[]> projection_effective_bias;

  // Scale and zero point for intermediate tensors.
  // Used only in the 8x8_8 case.
  int32_t intermediate_scale_a[8];
  int32_t intermediate_scale_b[8];
  int32_t intermediate_zp[12];
};

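// Evaluates the LSTM over the input sequence using float weights and
// activations, updating `output_state` and `cell_state` and writing the
// result to `output`.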
TfLiteStatus EvalFloat(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output);

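// Evaluates the LSTM in hybrid mode: quantized weights (e.g. int8) with
// float activations. Inputs and states are quantized on the fly into the
// provided scratch tensors; the *_ledger tensors carry sparsity ledgers for
// weights stored in a sparse format.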
TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_input_weights_ledger,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_forget_weights_ledger,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_cell_weights_ledger,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* input_to_output_weights_ledger,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_input_weights_ledger,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_forget_weights_ledger,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_cell_weights_ledger,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* recurrent_to_output_weights_ledger,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights,
    const TfLiteTensor* projection_weights_ledger,
    const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
    bool forward_sequence, bool time_major, int output_offset,
    TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
    TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
    TfLiteTensor* output_state, TfLiteTensor* cell_state,
    TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
    TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
    TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
    bool* compute_row_sums, CpuBackendContext* context);

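// Evaluates the fully integer LSTM with 8-bit weights and activations and a
// 16-bit cell state, using the precomputed quantization parameters in
// `integer_lstm_param`.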
TfLiteStatus EvalInteger8x8_16(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    CpuBackendContext* context);

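// Evaluates the fully integer LSTM with 8-bit weights, activations, and
// intermediate results; this variant uses the intermediate_scale_* and
// intermediate_zp fields of IntegerLstmParameter.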
TfLiteStatus EvalInteger8x8_8(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    TfLiteTensor* scratch6, TfLiteTensor* scratch7);

}  // namespace lstm_eval
}  // namespace builtin
}  // namespace ops
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_