/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "QuantizedLSTM.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"

#include "Tracing.h"

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

using tflite::Dims;

// The function below is taken from the TF Lite implementation in order to decouple
// NN API from the TF Lite dependency. The original function, with a description of its
// parameters and types, can be found at this link:
// https://github.com/tensorflow/tensorflow/blob/0d697e5fc4c05c699eea0764364104ea500ccc68/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h#L1926
//
// clang-format off
template <int StateIntegerBits>
void quantizedLstmStep(const uint8_t* input_data_uint8, const Dims<4>& input_dims,
                       const uint8_t* prev_activ_data_uint8,
                       const Dims<4>& prev_activ_dims, const uint8_t* weights_data_uint8,
                       const Dims<4>& weights_dims, const int32_t* bias_data_int32,
                       const Dims<4>& bias_dims, const int16_t* prevCellState_data_int16,
                       const Dims<4>& prevCellState_dims, int16_t* output_state_data_int16,
                       const Dims<4>& output_state_dims, uint8_t* output_activ_data_uint8,
                       const Dims<4>& output_activ_dims, uint8_t* concat_temp_data_uint8,
                       const Dims<4>& concat_temp_dims, int16_t* activ_temp_data_int16,
                       const Dims<4>& activ_temp_dims, int32_t weights_zero_point,
                       int32_t accum_multiplier, int accum_shift) {
  // Gather dimensions information, and perform consistency checks.
  const int outer_size =
      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prevCellState_dims,
                              output_state_dims, output_activ_dims);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
  const int input_depth = ArraySize(input_dims, 0);
  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
                  1);
  const int intern_activ_depth =
      MatchingArraySize(weights_dims, 1, bias_dims, 0);
  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingArraySize(prevCellState_dims, 0, prev_activ_dims, 0,
                        output_state_dims, 0, output_activ_dims, 0);
  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
  const int fc_output_depth =
      MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
  const int fc_accum_depth = ArraySize(weights_dims, 0);
  TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
                                                prev_activ_data_uint8};
  Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
  tflite::reference_ops::Concatenation<tflite::FusedActivationFunctionType::kNone, uint8_t>(
      0, concat_input_arrays_data, concat_input_arrays_dims, 2,
      concat_temp_data_uint8, concat_temp_dims);
  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32-bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits, so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
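  //
  // For illustration: a 16-bit fixed-point value with 3 integer bits has 12
  // fractional bits, so a raw int16 value r represents the real number
  // r * 2^-12. For example, raw 4096 is 1.0 and the maximum raw value 32767
  // is roughly 7.9998, matching the stated [-8, 8] range.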
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias-value.
      int32_t accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        int16_t input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16_t weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
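      //
      // Roughly speaking (a sketch of the usual TF Lite convention, not a
      // formal contract), the call below computes approximately
      //   accum ~= accum * accum_multiplier * 2^accum_shift / 2^31
      // with rounding and saturation, where accum_multiplier and accum_shift
      // come from tflite::QuantizeMultiplier() in eval() below.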
      accum =
          tflite::MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type, i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See the
      // comment on the StateIntegerBits template parameter in the original
      // TF Lite implementation linked above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
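      //
      // For illustration: with StateIntegerBits == 4, FS has 11 fractional
      // bits, so a raw int16 value r represents r * 2^-11; e.g. raw 2048 is
      // 1.0 and raw 32767 is approximately 16.0.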
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prevCellState = FS::FromRaw(prevCellState_data_int16[b * output_depth + c]);
      FS prevCellState_times_forget_state = forget_gate_output * prevCellState;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prevCellState_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the original TF Lite function comment there
      // is no significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
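      //
      // A worked example: an output_activ_int16 raw value of 16384 is 0.5 in
      // F0; RoundingDivideByPOT(16384, 8) gives 64, which is within
      // [-128, 127], so the stored uint8 value is 128 + 64 == 192. With the
      // output scale of 1/128 and zero point of 128 checked in prepare(),
      // that byte dequantizes back to 0.5.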
      int16_t rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16_t clamped_output_activ =
          std::max<int16_t>(-128, std::min<int16_t>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] =
          128 + clamped_output_activ;
    }
  }
}
// clang-format on

// The function assigns a 2D matrix to a submatrix of the weights at the given
// row and column offsets.
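// For example (an illustrative note): copying a [2 x 3] submatrix with
// offset_row == 4 and offset_column == 0 into a weights matrix with
// weightsDims == {8, 5} fills rows 4..5 and columns 0..2 of the destination,
// leaving the rest of the buffer untouched.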
void assignWeightsSubmatrix(const RunTimeOperandInfo* submatrix, const int32_t offset_row,
                            const int32_t offset_column, const std::vector<uint32_t>& weightsDims,
                            uint8_t* weights) {
    const uint8_t* submatrixValues = GetBuffer<uint8_t>(submatrix);
    const std::vector<uint32_t> submatrixDims = submatrix->shape().dimensions;
    for (uint32_t i = 0; i < submatrixDims[0] * submatrixDims[1]; ++i) {
        const uint32_t row = i / submatrixDims[1];
        const uint32_t column = i % submatrixDims[1];
        weights[(row + offset_row) * weightsDims[1] + column + offset_column] = submatrixValues[i];
    }
}

}  // namespace

QuantizedLSTMCell::QuantizedLSTMCell(const Operation& operation,
                                     std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    inputToInputWeights_ = GetInput(operation, operands, kInputToInputWeightsTensor);
    inputToForgetWeights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    inputToCellWeights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    inputToOutputWeights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrentToInputWeights_ = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    recurrentToForgetWeights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrentToCellWeights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrentToOutputWeights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    inputGateBias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forgetGateBias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cellGateBias_ = GetInput(operation, operands, kCellGateBiasTensor);
    outputGateBias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    prevCellState_ = GetInput(operation, operands, kPrevCellStateTensor);
    prevOutput_ = GetInput(operation, operands, kPrevOutputTensor);

    cellStateOut_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);
}

bool QuantizedLSTMCell::prepare(const Operation& operation,
                                std::vector<RunTimeOperandInfo>& operands, Shape* cellStateOutShape,
                                Shape* outputShape) {
    auto input = GetInput(operation, operands, kInputTensor);
    NN_RET_CHECK_EQ(NumDimensions(input), 2);
    NN_RET_CHECK_EQ(input->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(input->zeroPoint, 128);
    const uint32_t numBatches = SizeOfDimension(input, 0);
    const uint32_t inputSize = SizeOfDimension(input, 1);

    auto prevOutput = GetInput(operation, operands, kPrevOutputTensor);
    NN_RET_CHECK_EQ(NumDimensions(prevOutput), 2);
    NN_RET_CHECK_EQ(SizeOfDimension(prevOutput, 0), numBatches);
    NN_RET_CHECK_EQ(prevOutput->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(prevOutput->zeroPoint, 128);
    const uint32_t outputSize = SizeOfDimension(prevOutput, 1);

    auto inputToInputWeights = GetInput(operation, operands, kInputToInputWeightsTensor);
    const float weightsScale = inputToInputWeights->scale;
    NN_RET_CHECK(weightsScale != 0);
    const float weightsZeroPoint = inputToInputWeights->zeroPoint;

    auto checkWeightsShape = [&](const RunTimeOperandInfo* weights, uint32_t columns) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(weights), 2);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 0), outputSize);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 1), columns);
        NN_RET_CHECK_EQ(weights->scale, weightsScale);
        NN_RET_CHECK_EQ(weights->zeroPoint, weightsZeroPoint);
        return true;
    };

    auto inputToForgetWeights = GetInput(operation, operands, kInputToForgetWeightsTensor);
    auto inputToCellWeights = GetInput(operation, operands, kInputToCellWeightsTensor);
    auto inputToOutputWeights = GetInput(operation, operands, kInputToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(inputToInputWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToForgetWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToCellWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToOutputWeights, inputSize));

    auto recurrentToInputWeights = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    auto recurrentToForgetWeights = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    auto recurrentToCellWeights = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    auto recurrentToOutputWeights = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(recurrentToInputWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToForgetWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToCellWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToOutputWeights, outputSize));

    auto inputGateBias = GetInput(operation, operands, kInputGateBiasTensor);
    const float biasScale = inputGateBias->scale;
    NN_RET_CHECK_EQ(biasScale, weightsScale / 128.0);
    const float biasZeroPoint = inputGateBias->zeroPoint;
    NN_RET_CHECK_EQ(biasZeroPoint, 0);

    auto checkBiasShape = [&](const RunTimeOperandInfo* bias) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(bias), 1);
        NN_RET_CHECK_EQ(SizeOfDimension(bias, 0), outputSize);
        NN_RET_CHECK_EQ(bias->scale, biasScale);
        NN_RET_CHECK_EQ(bias->zeroPoint, biasZeroPoint);
        return true;
    };

    auto forgetGateBias = GetInput(operation, operands, kForgetGateBiasTensor);
    auto cellGateBias = GetInput(operation, operands, kCellGateBiasTensor);
    auto outputGateBias = GetInput(operation, operands, kOutputGateBiasTensor);
    NN_RET_CHECK(checkBiasShape(inputGateBias));
    NN_RET_CHECK(checkBiasShape(forgetGateBias));
    NN_RET_CHECK(checkBiasShape(cellGateBias));
    NN_RET_CHECK(checkBiasShape(outputGateBias));

    auto prevCellState = GetInput(operation, operands, kPrevCellStateTensor);
    NN_CHECK_EQ(NumDimensions(prevCellState), 2);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 0), numBatches);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 1), outputSize);
    NN_CHECK_EQ(prevCellState->zeroPoint, 0);
    // Cell state range for quantized LSTM is a function of StateIntegerBits and
    // can be calculated as:
    // [-2^StateIntegerBits, 2^StateIntegerBits * 32767/32768].
    // Therefore, for a fixed StateIntegerBits parameter, cell state scale is
    // equal to 2^StateIntegerBits * 2^(-15) = 2^(StateIntegerBits - 15) and
    // therefore:
    // StateIntegerBits = log2(cell state scale) + 15
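    //
    // A worked example: since only stateIntegerBits == 4 is accepted below,
    // the cell state scale must be 2^(4 - 15) = 2^-11; CheckedLog2 then yields
    // stateScaleLog2Rounded == -11 and 15 + (-11) == 4.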
    int stateScaleLog2Rounded;
    NN_CHECK(tflite::CheckedLog2(prevCellState->scale, &stateScaleLog2Rounded));
    const int stateIntegerBits = 15 + stateScaleLog2Rounded;
    // We only support StateIntegerBits == 4
    NN_CHECK(stateIntegerBits == 4);

    *cellStateOutShape = prevCellState->shape();
    *outputShape = prevOutput->shape();
    return true;
}

// The function concatenates the 8 weight matrices into one. The resulting
// matrix has a shape [4 * outputSize, outputSize + inputSize]. The matrix is
// constructed as follows:
// +-----------------------------------+
// | recurrentToInput  | inputToInput  |
// |-------------------+---------------|
// | recurrentToCell   | inputToCell   |
// |-------------------+---------------|
// | recurrentToForget | inputToForget |
// |-------------------+---------------|
// | recurrentToOutput | inputToOutput |
// +-----------------------------------+
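// For example (an illustrative note): with inputSize == 2 and outputSize == 4,
// the concatenated matrix has shape [16, 6]; the recurrentTo* blocks occupy
// columns 0..3 and the inputTo* blocks occupy columns 4..5, matching the
// column offsets passed to assignWeightsSubmatrix below.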
void QuantizedLSTMCell::concatenateWeights(const std::vector<uint32_t>& weightsDims,
                                           uint8_t* weights) {
    const int outputSize = SizeOfDimension(inputToInputWeights_, 0);

    assignWeightsSubmatrix(inputToInputWeights_, 0 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToCellWeights_, 1 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToForgetWeights_, 2 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToOutputWeights_, 3 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToInputWeights_, 0 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToCellWeights_, 1 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToForgetWeights_, 2 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToOutputWeights_, 3 * outputSize, 0, weightsDims, weights);
}

// The function concatenates the four bias vectors of shape [outputSize] into
// one vector of shape [4 * outputSize], in the same gate order as the weights
// above: input, cell, forget, output.
void QuantizedLSTMCell::concatenateBiases(uint32_t outputSize, int32_t* bias) {
    memcpy(bias + 0 * outputSize, GetBuffer<int32_t>(inputGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 1 * outputSize, GetBuffer<int32_t>(cellGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 2 * outputSize, GetBuffer<int32_t>(forgetGateBias_),
           sizeof(int32_t) * outputSize);
    memcpy(bias + 3 * outputSize, GetBuffer<int32_t>(outputGateBias_),
           sizeof(int32_t) * outputSize);
}

bool QuantizedLSTMCell::eval() {
    NNTRACE_COMP("QuantizedLSTM::eval");

    Shape weightsShape;
    weightsShape.dimensions = {4 * SizeOfDimension(prevOutput_, 1),
                               SizeOfDimension(input_, 1) + SizeOfDimension(prevOutput_, 1)};
    std::vector<uint8_t> weights(getNumberOfElements(weightsShape));
    concatenateWeights(weightsShape.dimensions, weights.data());

    Shape biasShape;
    biasShape.dimensions = {getSizeOfDimension(weightsShape, 0)};
    std::vector<int32_t> bias(getNumberOfElements(biasShape));
    concatenateBiases(SizeOfDimension(prevOutput_, 1), bias.data());

    Shape concatTempShape;
    concatTempShape.dimensions = {SizeOfDimension(input_, 0), getSizeOfDimension(weightsShape, 1)};

    Shape activationTempShape;
    activationTempShape.dimensions = {SizeOfDimension(input_, 0),
                                      getSizeOfDimension(weightsShape, 0)};

    std::vector<uint8_t> concatTemp(getNumberOfElements(concatTempShape));
    std::vector<int16_t> activationTemp(getNumberOfElements(activationTempShape));

    // From https://arxiv.org/pdf/1712.05877, for a fully-connected layer, the
    // accumulator multiplier is equal to:
    // (input scale) * (weights scale) / (fully-connected output scale)
    // In our case the fully-connected output scale is fixed and equal to
    // 2^(-12) (see the LSTMCell definition in TF Lite for more details on that).
    // But the bias scale is set to (input scale) * (weights scale) (also from the
    // paper), so we can multiply it by the inverse of the fc-output scale to get
    // the multiplier value:
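    //
    // A worked numeric example (added for clarity): with an input scale of
    // 1/128 and a weights scale of 1/256, the bias scale checked in prepare()
    // is 1/32768, so realAccumMultiplier == 4096 * (1/32768) == 0.125.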
    double realAccumMultiplier = 4096 * inputGateBias_->scale;
    int32_t accumMultiplier;
    int accumShift;
    tflite::QuantizeMultiplier(realAccumMultiplier, &accumMultiplier, &accumShift);
    quantizedLstmStep<4>(
            // Inputs.
            GetBuffer<const uint8_t>(input_), convertShapeToDims(input_->shape()),
            GetBuffer<const uint8_t>(prevOutput_), convertShapeToDims(prevOutput_->shape()),
            weights.data(), convertShapeToDims(weightsShape), bias.data(),
            convertShapeToDims(biasShape), GetBuffer<const int16_t>(prevCellState_),
            convertShapeToDims(prevCellState_->shape()),
            // Outputs.
            GetBuffer<int16_t>(cellStateOut_), convertShapeToDims(cellStateOut_->shape()),
            GetBuffer<uint8_t>(output_), convertShapeToDims(output_->shape()), concatTemp.data(),
            convertShapeToDims(concatTempShape), activationTemp.data(),
            convertShapeToDims(activationTempShape), inputToInputWeights_->zeroPoint,
            accumMultiplier, accumShift);
    return true;
}

}  // namespace nn
}  // namespace android