1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "QLSTM.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22
23 #include "CpuExecutor.h"
24 #include "OperationsExecutionUtils.h"
25
26 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
27 #include "QuantUtils.h"
28 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
29
30 namespace android {
31 namespace nn {
32 namespace qlstm {
33
34 namespace {
35
hasTensor(IOperationExecutionContext * context,const uint32_t tensor)36 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
37 return context->getInputBuffer(tensor) != nullptr;
38 }
39
40 } // namespace
41
// Validates the QUANTIZED_LSTM operand list (presence, ranks, dimensions) and
// derives the output shapes from the previous-output / previous-cell-state
// inputs. Returns true on success; each NN_RET_CHECK* logs a diagnostic and
// makes this function return false on failure.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredTensorInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kPrevOutputTensor,
            kPrevCellStateTensor,
    };
    for (const int tensor : requiredTensorInputs) {
        NN_RET_CHECK(!context->isOmittedInput(tensor))
                << "required input " << tensor << " is omitted";
    }

    // Input must be rank-2: [batchSize, inputSize].
    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 2u) << "Invalid input tensor rank: " << inputRank;

    const uint32_t batchSize = getSizeOfDimension(inputShape, 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 1);

    // numUnits (cell count) is read off the input->output weights:
    // [numUnits, inputSize].
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numUnits = getSizeOfDimension(inputToOutputShape, 0);

    // outputSize is read off the recurrent->output weights:
    // [numUnits, outputSize].
    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numUnits);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    // Optional input-gate weights (absent under CIFG): [numUnits, inputSize].
    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    // Required input->gate weights: [numUnits, inputSize].
    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    // Optional recurrent input-gate weights (absent under CIFG):
    // [numUnits, outputSize].
    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    // Required recurrent->gate weights: [numUnits, outputSize].
    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input-gate's parameters are either all present (non-CIFG) or
    // not at all (CIFG).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    // Optional peephole weight vectors: [numUnits] each.
    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numUnits);
    }

    // Making sure the peephole weights are there all or none.
    // Under CIFG the cell->input peephole is not required even if the other
    // two peephole weights are present.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    // Input-gate bias is mandatory without CIFG and forbidden with CIFG.
    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numUnits);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    // Required gate biases: [numUnits] each.
    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numUnits);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numUnits);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numUnits);

    // Optional projection: weights [outputSize, numUnits], bias [outputSize].
    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numUnits);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    // State inputs: previous output [batchSize, outputSize] and previous cell
    // state [batchSize, numUnits].
    const Shape outputStateShape = context->getInputShape(kPrevOutputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numUnits);

    // Optional layer-norm weight vectors: [numUnits] each.
    if (hasTensor(context, kInputLayerNormTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kForgetLayerNormTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kCellLayerNormTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kOutputLayerNormTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numUnits);
    }

    // Layer-norm weights must be all-present or all-absent; under CIFG the
    // input layer-norm weights are excluded (and must not be provided).
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) &&
                                                    hasTensor(context, kCellLayerNormTensor) &&
                                                    hasTensor(context, kOutputLayerNormTensor)) ||
                                                   (!hasTensor(context, kForgetLayerNormTensor) &&
                                                    !hasTensor(context, kCellLayerNormTensor) &&
                                                    !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) &&
                                                hasTensor(context, kForgetLayerNormTensor) &&
                                                hasTensor(context, kCellLayerNormTensor) &&
                                                hasTensor(context, kOutputLayerNormTensor)) ||
                                               (!hasTensor(context, kInputLayerNormTensor) &&
                                                !hasTensor(context, kForgetLayerNormTensor) &&
                                                !hasTensor(context, kCellLayerNormTensor) &&
                                                !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Outputs take their dimensions from the corresponding state inputs;
    // kOutputStateOutTensor and kOutputTensor share the prev-output shape.
    const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);
    outputShape.dimensions = prevOutputShape.dimensions;

    const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);
    Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor);
    cellStateOutShape.dimensions = prevCellStateShape.dimensions;

    return context->setOutputShape(kOutputStateOutTensor, outputShape) &&
           context->setOutputShape(kCellStateOutTensor, cellStateOutShape) &&
           context->setOutputShape(kOutputTensor, outputShape);
}
249
250 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
// Computes one step of the quantized (8-bit activation, 16-bit cell) LSTM on
// the CPU. Derives per-matmul effective quantized scales and precomputed
// zero-point biases, then evaluates the gates in fixed point — forget,
// modulation (cell), input (or 1 - forget under CIFG), output — followed by
// the cell-state update, hidden-state computation and optional projection.
// Assumes prepare() has already validated the operand shapes.
bool execute(IOperationExecutionContext* context) {
    // Gets the inputs.
    const Shape inputShape = context->getInputShape(kInputTensor);
    const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor);
    const Shape recurrentToInputWeightsShape =
            context->getInputShape(kRecurrentToInputWeightsTensor);
    const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
    const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
    const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor);
    const Shape recurrentToForgetWeightsShape =
            context->getInputShape(kRecurrentToForgetWeightsTensor);
    const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
    const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
    const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor);
    const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
    const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor);
    const Shape recurrentToOutputWeightsShape =
            context->getInputShape(kRecurrentToOutputWeightsTensor);
    const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
    const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
    const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor);
    const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
    const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);

    // Problem dimensions, already validated by prepare().
    const uint32_t batchSize = inputShape.dimensions[0];
    const uint32_t inputSize = inputShape.dimensions[1];
    const uint32_t numUnits = inputToOutputWeightsShape.dimensions[0];
    const uint32_t outputSize = recurrentToOutputWeightsShape.dimensions[1];

    // Scalar parameters: clipping thresholds, per-gate intermediate scales,
    // and the hidden-state (pre-projection) quantization parameters.
    const float cellClip = context->getInputValue<float>(kCellClip);
    const float projectionClip = context->getInputValue<float>(kProjectionClip);
    const float inputIntermediateScale = context->getInputValue<float>(kInputIntermediateScale);
    const float forgetIntermediateScale = context->getInputValue<float>(kForgetIntermediateScale);
    const float cellIntermediateScale = context->getInputValue<float>(kCellIntermediateScale);
    const float outputIntermediateScale = context->getInputValue<float>(kOutputIntermediateScale);
    const int8_t hiddenStateZeroPoint = context->getInputValue<int8_t>(kHiddenStateZeroPoint);
    const float hiddenStateScale = context->getInputValue<float>(kHiddenStateScale);

    const int8_t* inputBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputTensor));

    // Optional operands have null buffers; their presence drives the
    // CIFG / peephole / layer-norm / projection paths below.
    const int8_t* inputToInputWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToInputWeightsTensor));
    // CIFG (coupled input-forget gate) is signaled by the absence of the
    // input-to-input weights.
    const bool useCifg = (inputToInputWeightsBuffer == nullptr);
    const int8_t* recurrentToInputWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToInputWeightsTensor));
    const int16_t* cellToInputBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToInputWeightsTensor));
    const int16_t* inputLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kInputLayerNormTensor));
    const int32_t* inputBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kInputGateBiasTensor));

    const int8_t* inputToForgetWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToForgetWeightsTensor));
    const int8_t* recurrentToForgetWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToForgetWeightsTensor));
    const int16_t* cellToForgetBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToForgetWeightsTensor));
    const int16_t* forgetLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kForgetLayerNormTensor));
    const int32_t* forgetBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kForgetGateBiasTensor));

    const int8_t* inputToCellWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToCellWeightsTensor));
    const int8_t* recurrentToCellWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kRecurrentToCellWeightsTensor));
    const int16_t* cellLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellLayerNormTensor));
    const int32_t* cellBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kCellGateBiasTensor));

    const int8_t* inputToOutputWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToOutputWeightsTensor));
    const int8_t* recurrentToOutputWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToOutputWeightsTensor));
    const int16_t* cellToOutputBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToOutputWeightsTensor));
    const int16_t* outputLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kOutputLayerNormTensor));
    const int32_t* outputBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kOutputGateBiasTensor));

    const int8_t* projectionWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor));
    const int32_t* projectionBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kProjectionBiasTensor));

    const int8_t* prevOutputBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kPrevOutputTensor));
    const int16_t* prevCellStateBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kPrevCellStateTensor));

    uint8_t* outputStateBuffer =
            reinterpret_cast<uint8_t*>(context->getOutputBuffer(kOutputStateOutTensor));
    int16_t* cellStateBuffer =
            reinterpret_cast<int16_t*>(context->getOutputBuffer(kCellStateOutTensor));
    int8_t* outputBuffer = reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputTensor));

    // Calculates and decomposes effective scales.
    // This is for optimizing the matmul calculation.
    // The cell state scale must be an exact power of two, 2^cellShift, with
    // cellShift <= -9 (i.e. scale <= 2^-9).
    int cellShift;
    NN_RET_CHECK(CheckedLog2(prevCellStateShape.scale, &cellShift));
    NN_RET_CHECK(cellShift <= -9);

    // Non-CIFG only: effective scales feeding the input gate. Each effective
    // scale is decomposed by QuantizeMultiplier into a fixed-point multiplier
    // (A) and a shift (B).
    int32_t inputToInputEffectiveScaleA;
    int32_t inputToInputEffectiveScaleB;
    int32_t recurrentToInputEffectiveScaleA;
    int32_t recurrentToInputEffectiveScaleB;
    int32_t cellToInputEffectiveScaleA;
    int32_t cellToInputEffectiveScaleB;
    if (!useCifg) {
        const float inputToInputEffectiveScale =
                inputToInputWeightsShape.scale * inputShape.scale / inputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(inputToInputEffectiveScale, &inputToInputEffectiveScaleA,
                                        &inputToInputEffectiveScaleB));
        const float recurrentToInputEffectiveScale =
                recurrentToInputWeightsShape.scale * prevOutputShape.scale / inputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(recurrentToInputEffectiveScale,
                                        &recurrentToInputEffectiveScaleA,
                                        &recurrentToInputEffectiveScaleB));
        if (cellToInputBuffer != nullptr) {
            const float cellToInputEffectiveScale =
                    std::pow(2, cellShift) * cellToInputShape.scale / inputIntermediateScale;
            NN_RET_CHECK(QuantizeMultiplier(cellToInputEffectiveScale, &cellToInputEffectiveScaleA,
                                            &cellToInputEffectiveScaleB));
        }
    }

    int32_t inputLayerNormScaleA;
    int32_t inputLayerNormScaleB;
    if (inputLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(inputLayerNormShape.scale, &inputLayerNormScaleA,
                                        &inputLayerNormScaleB));
    }

    // Forget gate effective scales.
    const float inputToForgetEffectiveScale =
            inputToForgetWeightsShape.scale * inputShape.scale / forgetIntermediateScale;
    int32_t inputToForgetEffectiveScaleA;
    int32_t inputToForgetEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToForgetEffectiveScale, &inputToForgetEffectiveScaleA,
                                    &inputToForgetEffectiveScaleB));
    const float recurrentToForgetEffectiveScale =
            recurrentToForgetWeightsShape.scale * prevOutputShape.scale / forgetIntermediateScale;
    int32_t recurrentToForgetEffectiveScaleA;
    int32_t recurrentToForgetEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToForgetEffectiveScale,
                                    &recurrentToForgetEffectiveScaleA,
                                    &recurrentToForgetEffectiveScaleB));
    int32_t cellToForgetEffectiveScaleA;
    int32_t cellToForgetEffectiveScaleB;
    if (cellToForgetBuffer != nullptr) {
        const float cellToForgetEffectiveScale =
                std::pow(2, cellShift) * cellToForgetShape.scale / forgetIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(cellToForgetEffectiveScale, &cellToForgetEffectiveScaleA,
                                        &cellToForgetEffectiveScaleB));
    }
    int32_t forgetLayerNormScaleA;
    int32_t forgetLayerNormScaleB;
    if (forgetLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(forgetLayerNormShape.scale, &forgetLayerNormScaleA,
                                        &forgetLayerNormScaleB));
    }

    // Modulation (cell) gate effective scales.
    const float inputToCellEffectiveScale =
            inputToCellWeightsShape.scale * inputShape.scale / cellIntermediateScale;
    int32_t inputToCellEffectiveScaleA;
    int32_t inputToCellEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToCellEffectiveScale, &inputToCellEffectiveScaleA,
                                    &inputToCellEffectiveScaleB));
    const float recurrentToCellEffectiveScale =
            recurrentToCellWeightsShape.scale * prevOutputShape.scale / cellIntermediateScale;
    int32_t recurrentToCellEffectiveScaleA;
    int32_t recurrentToCellEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToCellEffectiveScale, &recurrentToCellEffectiveScaleA,
                                    &recurrentToCellEffectiveScaleB));

    int32_t cellLayerNormScaleA;
    int32_t cellLayerNormScaleB;
    if (cellLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(cellLayerNormShape.scale, &cellLayerNormScaleA,
                                        &cellLayerNormScaleB));
    }

    // Output gate effective scales.
    const float inputToOutputEffectiveScale =
            inputToOutputWeightsShape.scale * inputShape.scale / outputIntermediateScale;
    int32_t inputToOutputEffectiveScaleA;
    int32_t inputToOutputEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToOutputEffectiveScale, &inputToOutputEffectiveScaleA,
                                    &inputToOutputEffectiveScaleB));
    const float recurrentToOutputEffectiveScale =
            recurrentToOutputWeightsShape.scale * prevOutputShape.scale / outputIntermediateScale;
    int32_t recurrentToOutputEffectiveScaleA;
    int32_t recurrentToOutputEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToOutputEffectiveScale,
                                    &recurrentToOutputEffectiveScaleA,
                                    &recurrentToOutputEffectiveScaleB));
    int32_t cellToOutputEffectiveScaleA;
    int32_t cellToOutputEffectiveScaleB;
    if (cellToOutputBuffer != nullptr) {
        const float cellToOutputEffectiveScale =
                std::pow(2, cellShift) * cellToOutputShape.scale / outputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(cellToOutputEffectiveScale, &cellToOutputEffectiveScaleA,
                                        &cellToOutputEffectiveScaleB));
    }
    int32_t outputLayerNormScaleA;
    int32_t outputLayerNormScaleB;
    if (outputLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(outputLayerNormShape.scale, &outputLayerNormScaleA,
                                        &outputLayerNormScaleB));
    }

    // Both factors of the hidden-state multiply (sigmoid(output gate) and
    // tanh(cell state)) are Q0.15, hence the two 2^-15 factors.
    const float hiddenStateEffectiveScale = std::pow(2, -15) / hiddenStateScale * std::pow(2, -15);
    int32_t hiddenStateEffectiveScaleA;
    int32_t hiddenStateEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(hiddenStateEffectiveScale, &hiddenStateEffectiveScaleA,
                                    &hiddenStateEffectiveScaleB));

    int32_t projectionEffectiveScaleA;
    int32_t projectionEffectiveScaleB;
    if (projectionWeightsBuffer != nullptr) {
        const float projectionEffectiveScale =
                projectionWeightsShape.scale * hiddenStateScale / prevOutputShape.scale;
        NN_RET_CHECK(QuantizeMultiplier(projectionEffectiveScale, &projectionEffectiveScaleA,
                                        &projectionEffectiveScaleB));
    }

    // Calculates quantized clipping parameters.
    // Cell clip is expressed in cell-state units (int16 range).
    int16_t quantizedCellClip = 0;
    if (cellClip > 0.0) {
        quantizedCellClip = static_cast<int32_t>(
                std::min(std::max(cellClip / prevCellStateShape.scale, -32768.0f), 32767.0f));
    }
    // Projection clip is expressed in projection-weight units (int8 range).
    int8_t quantizedProjectionClip = 0;
    if (projectionClip > 0.0) {
        quantizedProjectionClip = static_cast<int32_t>(
                std::min(std::max(projectionClip / projectionWeightsShape.scale, -128.0f), 127.0f));
    }

    // Calculates effective bias.
    // This is for optimizing the matmul calculation.
    // Each effective bias folds (-zeroPoint * weights row sums) into the bias
    // so the matmuls can treat operands as if they were zero-offset.
    std::unique_ptr<int32_t[]> inputToInputEffectiveBias;
    std::unique_ptr<int32_t[]> recurrentToInputEffectiveBias;
    if (!useCifg) {
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                -inputShape.offset, inputToInputWeightsBuffer, inputToInputWeightsShape,
                /*bias=*/nullptr, &inputToInputEffectiveBias));
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                -prevOutputShape.offset, recurrentToInputWeightsBuffer,
                recurrentToInputWeightsShape,
                /*bias=*/nullptr, &recurrentToInputEffectiveBias));
    }

    std::unique_ptr<int32_t[]> inputToForgetEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToForgetWeightsBuffer, inputToForgetWeightsShape,
            /*bias=*/nullptr, &inputToForgetEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToForgetEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToForgetWeightsBuffer, recurrentToForgetWeightsShape,
            /*bias=*/nullptr, &recurrentToForgetEffectiveBias));

    std::unique_ptr<int32_t[]> inputToCellEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToCellWeightsBuffer, inputToCellWeightsShape,
            /*bias=*/nullptr, &inputToCellEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToCellEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToCellWeightsBuffer, recurrentToCellWeightsShape,
            /*bias=*/nullptr, &recurrentToCellEffectiveBias));

    std::unique_ptr<int32_t[]> inputToOutputEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToOutputWeightsBuffer, inputToOutputWeightsShape,
            /*bias=*/nullptr, &inputToOutputEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToOutputEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToOutputWeightsBuffer, recurrentToOutputWeightsShape,
            /*bias=*/nullptr, &recurrentToOutputEffectiveBias));

    // NOTE(review): projectionEffectiveBias is only precomputed when the
    // projection *bias* operand is present. If projection weights are present
    // without a bias, a null pointer is passed to the projection matmul below
    // — confirm MatrixBatchVectorMultiplyAccumulate tolerates a null bias.
    std::unique_ptr<int32_t[]> projectionEffectiveBias;
    if (projectionBiasBuffer != nullptr) {
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                hiddenStateZeroPoint, projectionWeightsBuffer, projectionWeightsShape,
                projectionBiasBuffer, &projectionEffectiveBias));
    }

    // Temporary buffers.
    std::vector<int16_t> inputGateBuffer(batchSize * numUnits);
    std::vector<int16_t> forgetGateBuffer(batchSize * numUnits);
    std::vector<int16_t> cellGateBuffer(batchSize * numUnits);
    std::vector<int16_t> outputGateBuffer(batchSize * numUnits);
    std::vector<int8_t> buffer8(batchSize * numUnits);

    // To avoid overflow when calculating layer norm.
    // NOTE(review): TFLite's integer LSTM computes this variance guard with
    // std::max(1, ...); std::min(1, ...) yields 0 whenever
    // 10000 * scale < 1 — confirm this is intentional.
    const int32_t inputInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * inputLayerNormShape.scale));
    const int32_t forgetInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * forgetLayerNormShape.scale));
    const int32_t cellInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * cellLayerNormShape.scale));
    const int32_t outputInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * outputLayerNormShape.scale));

    // Forget gate.
    // forgetGate = input * W_if + prevOutput * W_rf (+ peephole) -> layer
    // norm -> sigmoid.
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToForgetEffectiveBias.get(),
                                        inputToForgetWeightsBuffer, inputToForgetEffectiveScaleA,
                                        inputToForgetEffectiveScaleB, batchSize, inputSize,
                                        numUnits,
                                        /*outputZeroPoint=*/0, forgetGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToForgetEffectiveBias.get(), recurrentToForgetWeightsBuffer,
            recurrentToForgetEffectiveScaleA, recurrentToForgetEffectiveScaleB, batchSize,
            outputSize, numUnits,
            /*outputZeroPoint=*/0, forgetGateBuffer.data());
    if (cellToForgetBuffer != nullptr) {
        // NOTE(review): the peephole term reads cellStateBuffer (the output
        // operand, not yet written at this point) and passes outputSize as
        // the vector length — verify against prevCellStateBuffer / numUnits.
        VectorBatchVectorCwiseProductAccumulate(
                cellToForgetBuffer, outputSize, cellStateBuffer, batchSize,
                cellToForgetEffectiveScaleA, cellToForgetEffectiveScaleB, forgetGateBuffer.data());
    }
    if (forgetLayerNormBuffer != nullptr) {
        ApplyLayerNorm(forgetGateBuffer.data(), forgetLayerNormBuffer, forgetBiasBuffer,
                       forgetLayerNormScaleA, forgetLayerNormScaleB, forgetInvLargeValue, batchSize,
                       numUnits, forgetGateBuffer.data());
    }
    ApplySigmoid(forgetGateBuffer.data(), batchSize, numUnits, forgetGateBuffer.data());

    // Modulation gate.
    // cellGate = input * W_ic + prevOutput * W_rc -> layer norm -> tanh.
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToCellEffectiveBias.get(),
                                        inputToCellWeightsBuffer, inputToCellEffectiveScaleA,
                                        inputToCellEffectiveScaleB, batchSize, inputSize, numUnits,
                                        /*outputZeroPoint=*/0, cellGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToCellEffectiveBias.get(), recurrentToCellWeightsBuffer,
            recurrentToCellEffectiveScaleA, recurrentToCellEffectiveScaleB, batchSize, outputSize,
            numUnits,
            /*outputZeroPoint=*/0, cellGateBuffer.data());
    if (cellLayerNormBuffer != nullptr) {
        ApplyLayerNorm(cellGateBuffer.data(), cellLayerNormBuffer, cellBiasBuffer,
                       cellLayerNormScaleA, cellLayerNormScaleB, cellInvLargeValue, batchSize,
                       numUnits, cellGateBuffer.data());
    }
    ApplyTanh<3>(cellGateBuffer.data(), batchSize, numUnits, cellGateBuffer.data());

    // Input gate.
    // Under CIFG the input gate is coupled to the forget gate:
    // inputGate = 1 - forgetGate. Otherwise it is computed like the others.
    if (useCifg) {
        Sub1Vector(forgetGateBuffer.data(), batchSize * numUnits, inputGateBuffer.data());
    } else {
        MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToInputEffectiveBias.get(),
                                            inputToInputWeightsBuffer, inputToInputEffectiveScaleA,
                                            inputToInputEffectiveScaleB, batchSize, inputSize,
                                            numUnits,
                                            /*outputZeroPoint=*/0, inputGateBuffer.data());
        MatrixBatchVectorMultiplyAccumulate(
                prevOutputBuffer, recurrentToInputEffectiveBias.get(),
                recurrentToInputWeightsBuffer, recurrentToInputEffectiveScaleA,
                recurrentToInputEffectiveScaleB, batchSize, outputSize, numUnits,
                /*outputZeroPoint=*/0, inputGateBuffer.data());
        if (cellToInputBuffer != nullptr) {
            // NOTE(review): same concern as the forget-gate peephole above
            // (cellStateBuffer / outputSize arguments).
            VectorBatchVectorCwiseProductAccumulate(
                    cellToInputBuffer, outputSize, cellStateBuffer, batchSize,
                    cellToInputEffectiveScaleA, cellToInputEffectiveScaleB, inputGateBuffer.data());
        }
        if (inputLayerNormBuffer != nullptr) {
            ApplyLayerNorm(inputGateBuffer.data(), inputLayerNormBuffer, inputBiasBuffer,
                           inputLayerNormScaleA, inputLayerNormScaleB, inputInvLargeValue,
                           batchSize, numUnits, inputGateBuffer.data());
        }
        ApplySigmoid(inputGateBuffer.data(), batchSize, numUnits, inputGateBuffer.data());
    }

    // Cell.
    // cellState = forgetGate .* prevCellState + inputGate .* cellGate,
    // rescaled to the cell-state format (scale 2^cellShift), then optionally
    // clipped.
    CwiseMul(forgetGateBuffer.data(), prevCellStateBuffer, batchSize, numUnits,
             /*shift=*/15, forgetGateBuffer.data());
    CwiseMul(inputGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, 30 + cellShift,
             cellGateBuffer.data());
    CwiseAdd(forgetGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, cellStateBuffer);
    if (quantizedCellClip > 0) {
        CwiseClipping(cellStateBuffer, quantizedCellClip, batchSize, numUnits);
    }

    // Output gate.
    // outputGate = input * W_io + prevOutput * W_ro (+ peephole on the NEW
    // cell state) -> layer norm -> sigmoid.
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToOutputEffectiveBias.get(),
                                        inputToOutputWeightsBuffer, inputToOutputEffectiveScaleA,
                                        inputToOutputEffectiveScaleB, batchSize, inputSize,
                                        numUnits,
                                        /*outputZeroPoint=*/0, outputGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToOutputEffectiveBias.get(), recurrentToOutputWeightsBuffer,
            recurrentToOutputEffectiveScaleA, recurrentToOutputEffectiveScaleB, batchSize,
            outputSize, numUnits,
            /*outputZeroPoint=*/0, outputGateBuffer.data());
    if (cellToOutputBuffer != nullptr) {
        VectorBatchVectorCwiseProductAccumulate(
                cellToOutputBuffer, outputSize, cellStateBuffer, batchSize,
                cellToOutputEffectiveScaleA, cellToOutputEffectiveScaleB, outputGateBuffer.data());
    }
    if (outputLayerNormBuffer != nullptr) {
        ApplyLayerNorm(outputGateBuffer.data(), outputLayerNormBuffer, outputBiasBuffer,
                       outputLayerNormScaleA, outputLayerNormScaleB, outputInvLargeValue, batchSize,
                       numUnits, outputGateBuffer.data());
    }
    ApplySigmoid(outputGateBuffer.data(), batchSize, numUnits, outputGateBuffer.data());

    // Hidden.
    // hidden = outputGate .* tanh(cellState), requantized to int8 with the
    // hidden-state zero point. inputGateBuffer is reused as scratch for the
    // tanh result.
    ApplyTanh(cellShift + 15, cellStateBuffer, batchSize, numUnits, inputGateBuffer.data());
    CwiseMul(outputGateBuffer.data(), inputGateBuffer.data(), hiddenStateEffectiveScaleA,
             hiddenStateEffectiveScaleB, batchSize, numUnits, hiddenStateZeroPoint, buffer8.data());

    // Projection.
    // With projection weights, output = clip(hidden * W_proj + bias);
    // without, the hidden state is copied through directly (outputSize is
    // expected to equal numUnits in that case — buffer8 holds
    // batchSize * numUnits values).
    if (projectionWeightsBuffer != nullptr) {
        memset(outputBuffer, 0, batchSize * outputSize * sizeof(int8_t));
        MatrixBatchVectorMultiplyAccumulate(buffer8.data(), projectionEffectiveBias.get(),
                                            projectionWeightsBuffer, projectionEffectiveScaleA,
                                            projectionEffectiveScaleB, batchSize, numUnits,
                                            outputSize, prevOutputShape.offset, outputBuffer);
        if (quantizedProjectionClip > 0) {
            CwiseClipping(outputBuffer, quantizedProjectionClip, batchSize, outputSize);
        }
    } else {
        std::copy_n(buffer8.data(), batchSize * outputSize, outputBuffer);
    }

    // Copy output to output state out.
    for (unsigned int i = 0; i < batchSize * outputSize; ++i) {
        outputStateBuffer[i] = outputBuffer[i];
    }

    return true;
}
683 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
684
685 } // namespace qlstm
686
// Registers QUANTIZED_LSTM with the operation resolver. allowOmittedOperand
// is set because many operands (CIFG weights, peepholes, layer norm,
// projection) are optional.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(QUANTIZED_LSTM, qlstm::prepare, qlstm::execute,
                                         .allowOmittedOperand = true);
689
690 } // namespace nn
691 } // namespace android
692