• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "QLSTM.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 
23 #include "CpuExecutor.h"
24 #include "OperationsExecutionUtils.h"
25 
26 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
27 #include "QuantUtils.h"
28 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
29 
30 namespace android {
31 namespace nn {
32 namespace qlstm {
33 
34 namespace {
35 
hasTensor(IOperationExecutionContext * context,const uint32_t tensor)36 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
37     return context->getInputBuffer(tensor) != nullptr;
38 }
39 
40 }  // namespace
41 
/**
 * Validates the operand shapes of a quantized LSTM operation and computes the
 * shapes of its outputs.
 *
 * Derives the key dimensions (batchSize, inputSize, numUnits, outputSize) from
 * the input and weight tensors, then checks that every weight, bias, state,
 * peephole, projection and layer-normalization tensor has the expected rank
 * and dimensions, and that each group of optional tensors (CIFG input-gate
 * parameters, peephole weights, layer-norm weights) is supplied all-or-none.
 *
 * On success, sets the shapes of the three outputs (output state, cell state,
 * output) and returns true. On the first failed check, the NN_RET_CHECK*
 * macros log a message and return false from this function.
 */
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredTensorInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kPrevOutputTensor,
            kPrevCellStateTensor,
    };
    for (const int tensor : requiredTensorInputs) {
        NN_RET_CHECK(!context->isOmittedInput(tensor))
                << "required input " << tensor << " is omitted";
    }

    // The input tensor must be rank-2: [batchSize, inputSize].
    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 2u) << "Invalid input tensor rank: " << inputRank;

    const uint32_t batchSize = getSizeOfDimension(inputShape, 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 1);

    // numUnits (number of LSTM cells) is taken from the input-to-output
    // weights, which are shaped [numUnits, inputSize].
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numUnits = getSizeOfDimension(inputToOutputShape, 0);

    // outputSize is taken from the recurrent-to-output weights, which are
    // shaped [numUnits, outputSize].
    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numUnits);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    // Optional input-gate weights (absent when CIFG is used): [numUnits, inputSize].
    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    // Required input-to-gate weights: [numUnits, inputSize].
    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    // Optional recurrent input-gate weights (absent when CIFG is used):
    // [numUnits, outputSize].
    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    // Required recurrent-to-gate weights: [numUnits, outputSize].
    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input-gate's parameters are either all present (non-CIFG) or
    // not at all (CIFG).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    // Optional peephole weights are rank-1: [numUnits].
    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numUnits);
    }

    // Making sure the peephole weights are there all or none.
    // Under CIFG there is no input gate, so the cell-to-input peephole weight
    // is not required even when the other two peephole weights are present.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    // The input-gate bias is required for non-CIFG and forbidden for CIFG.
    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numUnits);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    // Required gate biases are rank-1: [numUnits].
    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numUnits);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numUnits);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numUnits);

    // Optional projection weights: [outputSize, numUnits].
    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numUnits);
    }

    // Optional projection bias: [outputSize].
    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    // Recurrent state inputs: previous output is [batchSize, outputSize] and
    // previous cell state is [batchSize, numUnits].
    const Shape outputStateShape = context->getInputShape(kPrevOutputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numUnits);

    // Optional per-gate layer-normalization weights are rank-1: [numUnits].
    if (hasTensor(context, kInputLayerNormTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kForgetLayerNormTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kCellLayerNormTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kOutputLayerNormTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numUnits);
    }

    // Layer-norm weights must be supplied all-or-none; under CIFG the
    // input-gate layer norm is excluded from the group (and must be absent).
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) &&
                                                    hasTensor(context, kCellLayerNormTensor) &&
                                                    hasTensor(context, kOutputLayerNormTensor)) ||
                                                   (!hasTensor(context, kForgetLayerNormTensor) &&
                                                    !hasTensor(context, kCellLayerNormTensor) &&
                                                    !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) &&
                                                hasTensor(context, kForgetLayerNormTensor) &&
                                                hasTensor(context, kCellLayerNormTensor) &&
                                                hasTensor(context, kOutputLayerNormTensor)) ||
                                               (!hasTensor(context, kInputLayerNormTensor) &&
                                                !hasTensor(context, kForgetLayerNormTensor) &&
                                                !hasTensor(context, kCellLayerNormTensor) &&
                                                !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Output shapes: only the dimensions are copied from the corresponding
    // state inputs; the remaining fields of each output Shape (as returned by
    // getOutputShape) are kept as-is.
    const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);
    outputShape.dimensions = prevOutputShape.dimensions;

    const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);
    Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor);
    cellStateOutShape.dimensions = prevCellStateShape.dimensions;

    // NOTE(review): outputShape is derived from kOutputTensor's shape but is
    // also applied to kOutputStateOutTensor -- this assumes the two outputs
    // share compatible shape metadata; confirm against the operation spec.
    return context->setOutputShape(kOutputStateOutTensor, outputShape) &&
           context->setOutputShape(kCellStateOutTensor, cellStateOutShape) &&
           context->setOutputShape(kOutputTensor, outputShape);
}
249 
250 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
execute(IOperationExecutionContext * context)251 bool execute(IOperationExecutionContext* context) {
252     // Gets the inputs.
253     const Shape inputShape = context->getInputShape(kInputTensor);
254     const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor);
255     const Shape recurrentToInputWeightsShape =
256             context->getInputShape(kRecurrentToInputWeightsTensor);
257     const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
258     const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
259     const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor);
260     const Shape recurrentToForgetWeightsShape =
261             context->getInputShape(kRecurrentToForgetWeightsTensor);
262     const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
263     const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
264     const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor);
265     const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor);
266     const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
267     const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor);
268     const Shape recurrentToOutputWeightsShape =
269             context->getInputShape(kRecurrentToOutputWeightsTensor);
270     const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
271     const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
272     const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor);
273     const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
274     const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);
275 
276     const uint32_t batchSize = inputShape.dimensions[0];
277     const uint32_t inputSize = inputShape.dimensions[1];
278     const uint32_t numUnits = inputToOutputWeightsShape.dimensions[0];
279     const uint32_t outputSize = recurrentToOutputWeightsShape.dimensions[1];
280 
281     const float cellClip = context->getInputValue<float>(kCellClip);
282     const float projectionClip = context->getInputValue<float>(kProjectionClip);
283     const float inputIntermediateScale = context->getInputValue<float>(kInputIntermediateScale);
284     const float forgetIntermediateScale = context->getInputValue<float>(kForgetIntermediateScale);
285     const float cellIntermediateScale = context->getInputValue<float>(kCellIntermediateScale);
286     const float outputIntermediateScale = context->getInputValue<float>(kOutputIntermediateScale);
287     const int8_t hiddenStateZeroPoint = context->getInputValue<int8_t>(kHiddenStateZeroPoint);
288     const float hiddenStateScale = context->getInputValue<float>(kHiddenStateScale);
289 
290     const int8_t* inputBuffer =
291             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputTensor));
292 
293     const int8_t* inputToInputWeightsBuffer =
294             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToInputWeightsTensor));
295     const bool useCifg = (inputToInputWeightsBuffer == nullptr);
296     const int8_t* recurrentToInputWeightsBuffer = reinterpret_cast<const int8_t*>(
297             context->getInputBuffer(kRecurrentToInputWeightsTensor));
298     const int16_t* cellToInputBuffer =
299             reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToInputWeightsTensor));
300     const int16_t* inputLayerNormBuffer =
301             reinterpret_cast<const int16_t*>(context->getInputBuffer(kInputLayerNormTensor));
302     const int32_t* inputBiasBuffer =
303             reinterpret_cast<const int32_t*>(context->getInputBuffer(kInputGateBiasTensor));
304 
305     const int8_t* inputToForgetWeightsBuffer =
306             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToForgetWeightsTensor));
307     const int8_t* recurrentToForgetWeightsBuffer = reinterpret_cast<const int8_t*>(
308             context->getInputBuffer(kRecurrentToForgetWeightsTensor));
309     const int16_t* cellToForgetBuffer =
310             reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToForgetWeightsTensor));
311     const int16_t* forgetLayerNormBuffer =
312             reinterpret_cast<const int16_t*>(context->getInputBuffer(kForgetLayerNormTensor));
313     const int32_t* forgetBiasBuffer =
314             reinterpret_cast<const int32_t*>(context->getInputBuffer(kForgetGateBiasTensor));
315 
316     const int8_t* inputToCellWeightsBuffer =
317             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToCellWeightsTensor));
318     const int8_t* recurrentToCellWeightsBuffer =
319             reinterpret_cast<const int8_t*>(context->getInputBuffer(kRecurrentToCellWeightsTensor));
320     const int16_t* cellLayerNormBuffer =
321             reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellLayerNormTensor));
322     const int32_t* cellBiasBuffer =
323             reinterpret_cast<const int32_t*>(context->getInputBuffer(kCellGateBiasTensor));
324 
325     const int8_t* inputToOutputWeightsBuffer =
326             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToOutputWeightsTensor));
327     const int8_t* recurrentToOutputWeightsBuffer = reinterpret_cast<const int8_t*>(
328             context->getInputBuffer(kRecurrentToOutputWeightsTensor));
329     const int16_t* cellToOutputBuffer =
330             reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToOutputWeightsTensor));
331     const int16_t* outputLayerNormBuffer =
332             reinterpret_cast<const int16_t*>(context->getInputBuffer(kOutputLayerNormTensor));
333     const int32_t* outputBiasBuffer =
334             reinterpret_cast<const int32_t*>(context->getInputBuffer(kOutputGateBiasTensor));
335 
336     const int8_t* projectionWeightsBuffer =
337             reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor));
338     const int32_t* projectionBiasBuffer =
339             reinterpret_cast<const int32_t*>(context->getInputBuffer(kProjectionBiasTensor));
340 
341     const int8_t* prevOutputBuffer =
342             reinterpret_cast<const int8_t*>(context->getInputBuffer(kPrevOutputTensor));
343     const int16_t* prevCellStateBuffer =
344             reinterpret_cast<const int16_t*>(context->getInputBuffer(kPrevCellStateTensor));
345 
346     uint8_t* outputStateBuffer =
347             reinterpret_cast<uint8_t*>(context->getOutputBuffer(kOutputStateOutTensor));
348     int16_t* cellStateBuffer =
349             reinterpret_cast<int16_t*>(context->getOutputBuffer(kCellStateOutTensor));
350     int8_t* outputBuffer = reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputTensor));
351 
352     // Calculates and decomposes effective scales.
353     // This is for optimizing the matmul calculation.
354     int cellShift;
355     NN_RET_CHECK(CheckedLog2(prevCellStateShape.scale, &cellShift));
356     NN_RET_CHECK(cellShift <= -9);
357 
358     int32_t inputToInputEffectiveScaleA;
359     int32_t inputToInputEffectiveScaleB;
360     int32_t recurrentToInputEffectiveScaleA;
361     int32_t recurrentToInputEffectiveScaleB;
362     int32_t cellToInputEffectiveScaleA;
363     int32_t cellToInputEffectiveScaleB;
364     if (!useCifg) {
365         const float inputToInputEffectiveScale =
366                 inputToInputWeightsShape.scale * inputShape.scale / inputIntermediateScale;
367         NN_RET_CHECK(QuantizeMultiplier(inputToInputEffectiveScale, &inputToInputEffectiveScaleA,
368                                         &inputToInputEffectiveScaleB));
369         const float recurrentToInputEffectiveScale =
370                 recurrentToInputWeightsShape.scale * prevOutputShape.scale / inputIntermediateScale;
371         NN_RET_CHECK(QuantizeMultiplier(recurrentToInputEffectiveScale,
372                                         &recurrentToInputEffectiveScaleA,
373                                         &recurrentToInputEffectiveScaleB));
374         if (cellToInputBuffer != nullptr) {
375             const float cellToInputEffectiveScale =
376                     std::pow(2, cellShift) * cellToInputShape.scale / inputIntermediateScale;
377             NN_RET_CHECK(QuantizeMultiplier(cellToInputEffectiveScale, &cellToInputEffectiveScaleA,
378                                             &cellToInputEffectiveScaleB));
379         }
380     }
381 
382     int32_t inputLayerNormScaleA;
383     int32_t inputLayerNormScaleB;
384     if (inputLayerNormBuffer != nullptr) {
385         NN_RET_CHECK(QuantizeMultiplier(inputLayerNormShape.scale, &inputLayerNormScaleA,
386                                         &inputLayerNormScaleB));
387     }
388 
389     const float inputToForgetEffectiveScale =
390             inputToForgetWeightsShape.scale * inputShape.scale / forgetIntermediateScale;
391     int32_t inputToForgetEffectiveScaleA;
392     int32_t inputToForgetEffectiveScaleB;
393     NN_RET_CHECK(QuantizeMultiplier(inputToForgetEffectiveScale, &inputToForgetEffectiveScaleA,
394                                     &inputToForgetEffectiveScaleB));
395     const float recurrentToForgetEffectiveScale =
396             recurrentToForgetWeightsShape.scale * prevOutputShape.scale / forgetIntermediateScale;
397     int32_t recurrentToForgetEffectiveScaleA;
398     int32_t recurrentToForgetEffectiveScaleB;
399     NN_RET_CHECK(QuantizeMultiplier(recurrentToForgetEffectiveScale,
400                                     &recurrentToForgetEffectiveScaleA,
401                                     &recurrentToForgetEffectiveScaleB));
402     int32_t cellToForgetEffectiveScaleA;
403     int32_t cellToForgetEffectiveScaleB;
404     if (cellToForgetBuffer != nullptr) {
405         const float cellToForgetEffectiveScale =
406                 std::pow(2, cellShift) * cellToForgetShape.scale / forgetIntermediateScale;
407         NN_RET_CHECK(QuantizeMultiplier(cellToForgetEffectiveScale, &cellToForgetEffectiveScaleA,
408                                         &cellToForgetEffectiveScaleB));
409     }
410     int32_t forgetLayerNormScaleA;
411     int32_t forgetLayerNormScaleB;
412     if (forgetLayerNormBuffer != nullptr) {
413         NN_RET_CHECK(QuantizeMultiplier(forgetLayerNormShape.scale, &forgetLayerNormScaleA,
414                                         &forgetLayerNormScaleB));
415     }
416 
417     const float inputToCellEffectiveScale =
418             inputToCellWeightsShape.scale * inputShape.scale / cellIntermediateScale;
419     int32_t inputToCellEffectiveScaleA;
420     int32_t inputToCellEffectiveScaleB;
421     NN_RET_CHECK(QuantizeMultiplier(inputToCellEffectiveScale, &inputToCellEffectiveScaleA,
422                                     &inputToCellEffectiveScaleB));
423     const float recurrentToCellEffectiveScale =
424             recurrentToCellWeightsShape.scale * prevOutputShape.scale / cellIntermediateScale;
425     int32_t recurrentToCellEffectiveScaleA;
426     int32_t recurrentToCellEffectiveScaleB;
427     NN_RET_CHECK(QuantizeMultiplier(recurrentToCellEffectiveScale, &recurrentToCellEffectiveScaleA,
428                                     &recurrentToCellEffectiveScaleB));
429 
430     int32_t cellLayerNormScaleA;
431     int32_t cellLayerNormScaleB;
432     if (cellLayerNormBuffer != nullptr) {
433         NN_RET_CHECK(QuantizeMultiplier(cellLayerNormShape.scale, &cellLayerNormScaleA,
434                                         &cellLayerNormScaleB));
435     }
436 
437     const float inputToOutputEffectiveScale =
438             inputToOutputWeightsShape.scale * inputShape.scale / outputIntermediateScale;
439     int32_t inputToOutputEffectiveScaleA;
440     int32_t inputToOutputEffectiveScaleB;
441     NN_RET_CHECK(QuantizeMultiplier(inputToOutputEffectiveScale, &inputToOutputEffectiveScaleA,
442                                     &inputToOutputEffectiveScaleB));
443     const float recurrentToOutputEffectiveScale =
444             recurrentToOutputWeightsShape.scale * prevOutputShape.scale / outputIntermediateScale;
445     int32_t recurrentToOutputEffectiveScaleA;
446     int32_t recurrentToOutputEffectiveScaleB;
447     NN_RET_CHECK(QuantizeMultiplier(recurrentToOutputEffectiveScale,
448                                     &recurrentToOutputEffectiveScaleA,
449                                     &recurrentToOutputEffectiveScaleB));
450     int32_t cellToOutputEffectiveScaleA;
451     int32_t cellToOutputEffectiveScaleB;
452     if (cellToOutputBuffer != nullptr) {
453         const float cellToOutputEffectiveScale =
454                 std::pow(2, cellShift) * cellToOutputShape.scale / outputIntermediateScale;
455         NN_RET_CHECK(QuantizeMultiplier(cellToOutputEffectiveScale, &cellToOutputEffectiveScaleA,
456                                         &cellToOutputEffectiveScaleB));
457     }
458     int32_t outputLayerNormScaleA;
459     int32_t outputLayerNormScaleB;
460     if (outputLayerNormBuffer != nullptr) {
461         NN_RET_CHECK(QuantizeMultiplier(outputLayerNormShape.scale, &outputLayerNormScaleA,
462                                         &outputLayerNormScaleB));
463     }
464 
465     const float hiddenStateEffectiveScale = std::pow(2, -15) / hiddenStateScale * std::pow(2, -15);
466     int32_t hiddenStateEffectiveScaleA;
467     int32_t hiddenStateEffectiveScaleB;
468     NN_RET_CHECK(QuantizeMultiplier(hiddenStateEffectiveScale, &hiddenStateEffectiveScaleA,
469                                     &hiddenStateEffectiveScaleB));
470 
471     int32_t projectionEffectiveScaleA;
472     int32_t projectionEffectiveScaleB;
473     if (projectionWeightsBuffer != nullptr) {
474         const float projectionEffectiveScale =
475                 projectionWeightsShape.scale * hiddenStateScale / prevOutputShape.scale;
476         NN_RET_CHECK(QuantizeMultiplier(projectionEffectiveScale, &projectionEffectiveScaleA,
477                                         &projectionEffectiveScaleB));
478     }
479 
480     // Calculates quantized clipping parameters.
481     int16_t quantizedCellClip = 0;
482     if (cellClip > 0.0) {
483         quantizedCellClip = static_cast<int32_t>(
484                 std::min(std::max(cellClip / prevCellStateShape.scale, -32768.0f), 32767.0f));
485     }
486     int8_t quantizedProjectionClip = 0;
487     if (projectionClip > 0.0) {
488         quantizedProjectionClip = static_cast<int32_t>(
489                 std::min(std::max(projectionClip / projectionWeightsShape.scale, -128.0f), 127.0f));
490     }
491 
492     // Calculates effective bias.
493     // This is for optimizing the matmul calculation.
494     std::unique_ptr<int32_t[]> inputToInputEffectiveBias;
495     std::unique_ptr<int32_t[]> recurrentToInputEffectiveBias;
496     if (!useCifg) {
497         NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
498                 -inputShape.offset, inputToInputWeightsBuffer, inputToInputWeightsShape,
499                 /*bias=*/nullptr, &inputToInputEffectiveBias));
500         NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
501                 -prevOutputShape.offset, recurrentToInputWeightsBuffer,
502                 recurrentToInputWeightsShape,
503                 /*bias=*/nullptr, &recurrentToInputEffectiveBias));
504     }
505 
506     std::unique_ptr<int32_t[]> inputToForgetEffectiveBias;
507     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
508             -inputShape.offset, inputToForgetWeightsBuffer, inputToForgetWeightsShape,
509             /*bias=*/nullptr, &inputToForgetEffectiveBias));
510     std::unique_ptr<int32_t[]> recurrentToForgetEffectiveBias;
511     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
512             -prevOutputShape.offset, recurrentToForgetWeightsBuffer, recurrentToForgetWeightsShape,
513             /*bias=*/nullptr, &recurrentToForgetEffectiveBias));
514 
515     std::unique_ptr<int32_t[]> inputToCellEffectiveBias;
516     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
517             -inputShape.offset, inputToCellWeightsBuffer, inputToCellWeightsShape,
518             /*bias=*/nullptr, &inputToCellEffectiveBias));
519     std::unique_ptr<int32_t[]> recurrentToCellEffectiveBias;
520     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
521             -prevOutputShape.offset, recurrentToCellWeightsBuffer, recurrentToCellWeightsShape,
522             /*bias=*/nullptr, &recurrentToCellEffectiveBias));
523 
524     std::unique_ptr<int32_t[]> inputToOutputEffectiveBias;
525     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
526             -inputShape.offset, inputToOutputWeightsBuffer, inputToOutputWeightsShape,
527             /*bias=*/nullptr, &inputToOutputEffectiveBias));
528     std::unique_ptr<int32_t[]> recurrentToOutputEffectiveBias;
529     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
530             -prevOutputShape.offset, recurrentToOutputWeightsBuffer, recurrentToOutputWeightsShape,
531             /*bias=*/nullptr, &recurrentToOutputEffectiveBias));
532 
533     std::unique_ptr<int32_t[]> projectionEffectiveBias;
534     if (projectionBiasBuffer != nullptr) {
535         NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
536                 hiddenStateZeroPoint, projectionWeightsBuffer, projectionWeightsShape,
537                 projectionBiasBuffer, &projectionEffectiveBias));
538     }
539 
540     // Temporary buffers.
541     std::vector<int16_t> inputGateBuffer(batchSize * numUnits);
542     std::vector<int16_t> forgetGateBuffer(batchSize * numUnits);
543     std::vector<int16_t> cellGateBuffer(batchSize * numUnits);
544     std::vector<int16_t> outputGateBuffer(batchSize * numUnits);
545     std::vector<int8_t> buffer8(batchSize * numUnits);
546 
547     // To avoid overflow when calculating layer norm.
548     const int32_t inputInvLargeValue =
549             std::min(1, static_cast<int32_t>(10000 * inputLayerNormShape.scale));
550     const int32_t forgetInvLargeValue =
551             std::min(1, static_cast<int32_t>(10000 * forgetLayerNormShape.scale));
552     const int32_t cellInvLargeValue =
553             std::min(1, static_cast<int32_t>(10000 * cellLayerNormShape.scale));
554     const int32_t outputInvLargeValue =
555             std::min(1, static_cast<int32_t>(10000 * outputLayerNormShape.scale));
556 
557     // Forget gate.
558     MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToForgetEffectiveBias.get(),
559                                         inputToForgetWeightsBuffer, inputToForgetEffectiveScaleA,
560                                         inputToForgetEffectiveScaleB, batchSize, inputSize,
561                                         numUnits,
562                                         /*outputZeroPoint=*/0, forgetGateBuffer.data());
563     MatrixBatchVectorMultiplyAccumulate(
564             prevOutputBuffer, recurrentToForgetEffectiveBias.get(), recurrentToForgetWeightsBuffer,
565             recurrentToForgetEffectiveScaleA, recurrentToForgetEffectiveScaleB, batchSize,
566             outputSize, numUnits,
567             /*outputZeroPoint=*/0, forgetGateBuffer.data());
568     if (cellToForgetBuffer != nullptr) {
569         VectorBatchVectorCwiseProductAccumulate(
570                 cellToForgetBuffer, outputSize, cellStateBuffer, batchSize,
571                 cellToForgetEffectiveScaleA, cellToForgetEffectiveScaleB, forgetGateBuffer.data());
572     }
573     if (forgetLayerNormBuffer != nullptr) {
574         ApplyLayerNorm(forgetGateBuffer.data(), forgetLayerNormBuffer, forgetBiasBuffer,
575                        forgetLayerNormScaleA, forgetLayerNormScaleB, forgetInvLargeValue, batchSize,
576                        numUnits, forgetGateBuffer.data());
577     }
578     ApplySigmoid(forgetGateBuffer.data(), batchSize, numUnits, forgetGateBuffer.data());
579 
580     // Modulation gate.
581     MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToCellEffectiveBias.get(),
582                                         inputToCellWeightsBuffer, inputToCellEffectiveScaleA,
583                                         inputToCellEffectiveScaleB, batchSize, inputSize, numUnits,
584                                         /*outputZeroPoint=*/0, cellGateBuffer.data());
585     MatrixBatchVectorMultiplyAccumulate(
586             prevOutputBuffer, recurrentToCellEffectiveBias.get(), recurrentToCellWeightsBuffer,
587             recurrentToCellEffectiveScaleA, recurrentToCellEffectiveScaleB, batchSize, outputSize,
588             numUnits,
589             /*outputZeroPoint=*/0, cellGateBuffer.data());
590     if (cellLayerNormBuffer != nullptr) {
591         ApplyLayerNorm(cellGateBuffer.data(), cellLayerNormBuffer, cellBiasBuffer,
592                        cellLayerNormScaleA, cellLayerNormScaleB, cellInvLargeValue, batchSize,
593                        numUnits, cellGateBuffer.data());
594     }
595     ApplyTanh<3>(cellGateBuffer.data(), batchSize, numUnits, cellGateBuffer.data());
596 
597     // Input gate.
598     if (useCifg) {
599         Sub1Vector(forgetGateBuffer.data(), batchSize * numUnits, inputGateBuffer.data());
600     } else {
601         MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToInputEffectiveBias.get(),
602                                             inputToInputWeightsBuffer, inputToInputEffectiveScaleA,
603                                             inputToInputEffectiveScaleB, batchSize, inputSize,
604                                             numUnits,
605                                             /*outputZeroPoint=*/0, inputGateBuffer.data());
606         MatrixBatchVectorMultiplyAccumulate(
607                 prevOutputBuffer, recurrentToInputEffectiveBias.get(),
608                 recurrentToInputWeightsBuffer, recurrentToInputEffectiveScaleA,
609                 recurrentToInputEffectiveScaleB, batchSize, outputSize, numUnits,
610                 /*outputZeroPoint=*/0, inputGateBuffer.data());
611         if (cellToInputBuffer != nullptr) {
612             VectorBatchVectorCwiseProductAccumulate(
613                     cellToInputBuffer, outputSize, cellStateBuffer, batchSize,
614                     cellToInputEffectiveScaleA, cellToInputEffectiveScaleB, inputGateBuffer.data());
615         }
616         if (inputLayerNormBuffer != nullptr) {
617             ApplyLayerNorm(inputGateBuffer.data(), inputLayerNormBuffer, inputBiasBuffer,
618                            inputLayerNormScaleA, inputLayerNormScaleB, inputInvLargeValue,
619                            batchSize, numUnits, inputGateBuffer.data());
620         }
621         ApplySigmoid(inputGateBuffer.data(), batchSize, numUnits, inputGateBuffer.data());
622     }
623 
624     // Cell.
625     CwiseMul(forgetGateBuffer.data(), prevCellStateBuffer, batchSize, numUnits,
626              /*shift=*/15, forgetGateBuffer.data());
627     CwiseMul(inputGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, 30 + cellShift,
628              cellGateBuffer.data());
629     CwiseAdd(forgetGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, cellStateBuffer);
630     if (quantizedCellClip > 0) {
631         CwiseClipping(cellStateBuffer, quantizedCellClip, batchSize, numUnits);
632     }
633 
634     // Output gate.
635     MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToOutputEffectiveBias.get(),
636                                         inputToOutputWeightsBuffer, inputToOutputEffectiveScaleA,
637                                         inputToOutputEffectiveScaleB, batchSize, inputSize,
638                                         numUnits,
639                                         /*outputZeroPoint=*/0, outputGateBuffer.data());
640     MatrixBatchVectorMultiplyAccumulate(
641             prevOutputBuffer, recurrentToOutputEffectiveBias.get(), recurrentToOutputWeightsBuffer,
642             recurrentToOutputEffectiveScaleA, recurrentToOutputEffectiveScaleB, batchSize,
643             outputSize, numUnits,
644             /*outputZeroPoint=*/0, outputGateBuffer.data());
645     if (cellToOutputBuffer != nullptr) {
646         VectorBatchVectorCwiseProductAccumulate(
647                 cellToOutputBuffer, outputSize, cellStateBuffer, batchSize,
648                 cellToOutputEffectiveScaleA, cellToOutputEffectiveScaleB, outputGateBuffer.data());
649     }
650     if (outputLayerNormBuffer != nullptr) {
651         ApplyLayerNorm(outputGateBuffer.data(), outputLayerNormBuffer, outputBiasBuffer,
652                        outputLayerNormScaleA, outputLayerNormScaleB, outputInvLargeValue, batchSize,
653                        numUnits, outputGateBuffer.data());
654     }
655     ApplySigmoid(outputGateBuffer.data(), batchSize, numUnits, outputGateBuffer.data());
656 
657     // Hidden.
658     ApplyTanh(cellShift + 15, cellStateBuffer, batchSize, numUnits, inputGateBuffer.data());
659     CwiseMul(outputGateBuffer.data(), inputGateBuffer.data(), hiddenStateEffectiveScaleA,
660              hiddenStateEffectiveScaleB, batchSize, numUnits, hiddenStateZeroPoint, buffer8.data());
661 
662     // Projection.
663     if (projectionWeightsBuffer != nullptr) {
664         memset(outputBuffer, 0, batchSize * outputSize * sizeof(int8_t));
665         MatrixBatchVectorMultiplyAccumulate(buffer8.data(), projectionEffectiveBias.get(),
666                                             projectionWeightsBuffer, projectionEffectiveScaleA,
667                                             projectionEffectiveScaleB, batchSize, numUnits,
668                                             outputSize, prevOutputShape.offset, outputBuffer);
669         if (quantizedProjectionClip > 0) {
670             CwiseClipping(outputBuffer, quantizedProjectionClip, batchSize, outputSize);
671         }
672     } else {
673         std::copy_n(buffer8.data(), batchSize * outputSize, outputBuffer);
674     }
675 
676     // Copy output to output state out.
677     for (unsigned int i = 0; i < batchSize * outputSize; ++i) {
678         outputStateBuffer[i] = outputBuffer[i];
679     }
680 
681     return true;
682 }
683 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
684 
685 }  // namespace qlstm
686 
// Register the QUANTIZED_LSTM operation with the CPU executor, wiring
// qlstm::prepare as its shape/validity check and qlstm::execute as its
// kernel. allowOmittedOperand = true lets callers leave optional operands
// unset (null buffers) — the implementation above explicitly null-checks
// the optional peephole, layer-norm, projection, and CIFG-related tensors
// before using them.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(QUANTIZED_LSTM, qlstm::prepare, qlstm::execute,
                                         .allowOmittedOperand = true);
689 
690 }  // namespace nn
691 }  // namespace android
692