• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefQLstmWorkload.hpp"
7 #include "Activation.hpp"
8 #include "Encoders.hpp"
9 #include "Decoders.hpp"
10 #include "LstmUtils.hpp"
11 #include "RefWorkloadUtils.hpp"
12 
13 namespace armnn
14 {
15 
RefQLstmWorkload(const QLstmQueueDescriptor & descriptor,const WorkloadInfo & info)16 RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
17         : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
18         , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
19         , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
20         , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
21         , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
22 
23         , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
24         , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
25         , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
26         , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
27 
28         , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
29         , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
30         , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
31 
32         , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
33         , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
34         , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
35         , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
36 
37         , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
38         , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
39 
40         , m_InputLayerNormWeightsTensor   (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
41         , m_ForgetLayerNormWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
42         , m_CellLayerNormWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
43         , m_OutputLayerNormWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
44 {}
45 
Execute() const46 void RefQLstmWorkload::Execute() const
47 {
48     // This is a porting of the QLSTM::Execute() method in the Android code base
49     // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
50     // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
51     // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
52     const DataType& internalType = armnn::DataType::QSymmS16;
53 
54     const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
55     const TensorInfo& outputStateInInfo = GetTensorInfo(m_Data.m_Inputs[1]);
56     const TensorInfo& cellStateInInfo = GetTensorInfo(m_Data.m_Inputs[2]);
57 
58     const TensorInfo& outputStateOutInfo = GetTensorInfo(m_Data.m_Outputs[0]);
59     const TensorInfo& cellStateOutInfo = GetTensorInfo(m_Data.m_Outputs[1]);
60     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[2]);
61 
62     const TensorShape& inputShape = inputInfo.GetShape();
63     const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
64     const TensorShape& cellStateInShape = cellStateInInfo.GetShape();
65 
66     // Infer numBatches, inputSize, outputSize and numUnits
67     const uint32_t numBatches = inputShape[0];
68     const uint32_t inputSize  = inputShape[1];
69     const uint32_t outputSize = outputStateInShape[1];
70     const uint32_t numUnits   = cellStateInShape[1];
71 
72     // Optional param settings
73     const bool cifgEnabled      = m_Data.m_Parameters.m_CifgEnabled;
74     const bool peepholeEnabled  = m_Data.m_Parameters.m_PeepholeEnabled;
75     const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
76     const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled;
77 
78     // Input decoders
79     std::unique_ptr<Decoder<float>> inputDecoder =
80             MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
81     std::unique_ptr<Decoder<float>> outputStateInDecoder =
82             MakeDecoder<float>(outputStateInInfo, m_Data.m_Inputs[1]->Map());
83     std::unique_ptr<Decoder<float>> cellStateInDecoder =
84             MakeDecoder<float>(cellStateInInfo, m_Data.m_Inputs[2]->Map());
85 
86     // Output decoders
87     std::unique_ptr<Decoder<float>> outputStateOutDecoder =
88             MakeDecoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
89     std::unique_ptr<Decoder<float>> cellStateOutDecoder =
90             MakeDecoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
91     std::unique_ptr<Decoder<float>> outputDecoder =
92             MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
93 
94     // Output encoders
95     std::unique_ptr<Encoder<float>> outputStateOutEncoder =
96             MakeEncoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
97     std::unique_ptr<Encoder<float>> cellStateOutEncoder =
98             MakeEncoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
99     std::unique_ptr<Encoder<float>> outputEncoder =
100             MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
101 
102     // Weights decoders
103     std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
104             m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
105     std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
106             m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
107     std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
108             m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());
109 
110     std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
111             m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
112     std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
113             m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
114     std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
115             m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());
116 
117     // Optional CIFG params
118     std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
119     std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
120     std::unique_ptr<Decoder<float>> inputGateBiasDecoder;
121 
122     // Optional Peephole params
123     std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
124     std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
125     std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;
126 
127     // Optional Projection params
128     std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
129     std::unique_ptr<Decoder<float>> projectionBiasDecoder;
130 
131     // Optional Layer Norm params
132     std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
133     std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
134     std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
135     std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;
136 
137     // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
138     std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
139     std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
140     std::unique_ptr<Decoder<float>> outputGateBiasDecoder;
141 
142     // Int16 vectors for internal state data (to be decoded/encoded)
143     const uint32_t stateTensorSize = numBatches * numUnits;
144     std::vector<int16_t> inputGateData(stateTensorSize);
145     std::vector<int16_t> cellGateData(stateTensorSize);
146     std::vector<int16_t> forgetGateData(stateTensorSize);
147     std::vector<int16_t> outputGateData(stateTensorSize);
148     std::vector<int32_t> hiddenStateData(stateTensorSize);
149     std::vector<int16_t> outputInt16Data(numBatches * outputSize);
150 
151     armnn::TensorInfo inputGateInfo(
152             {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
153     armnn::TensorInfo cellGateInfo(
154             {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
155     armnn::TensorInfo forgetGateInfo(
156             {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
157     armnn::TensorInfo outputGateInfo(
158             {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
159     armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
160                                       armnn::DataType::QAsymmS8,
161                                       m_Data.m_Parameters.m_HiddenStateScale,
162                                       m_Data.m_Parameters.m_HiddenStateZeroPoint);
163     armnn::TensorInfo outputInt16Info({numBatches , outputSize},
164                                       armnn::DataType::QSymmS16,
165                                       outputInfo.GetQuantizationScale(),
166                                       outputInfo.GetQuantizationOffset());
167 
168     // Decoders/Encoders for internal states
169     std::unique_ptr<Decoder<float>> inputGateDecoder =
170             MakeDecoder<float>(inputGateInfo, inputGateData.data());
171     std::unique_ptr<Decoder<float>> cellGateDecoder =
172             MakeDecoder<float>(cellGateInfo, cellGateData.data());
173     std::unique_ptr<Decoder<float>> forgetGateDecoder =
174             MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
175     std::unique_ptr<Decoder<float>> outputGateDecoder =
176             MakeDecoder<float>(outputGateInfo, outputGateData.data());
177     std::unique_ptr<Decoder<float>> hiddenStateDecoder =
178             MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());
179 
180     std::unique_ptr<Encoder<float>> inputGateEncoder =
181             MakeEncoder<float>(inputGateInfo, inputGateData.data());
182     std::unique_ptr<Encoder<float>> cellGateEncoder =
183             MakeEncoder<float>(cellGateInfo, cellGateData.data());
184     std::unique_ptr<Encoder<float>> forgetGateEncoder =
185             MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
186     std::unique_ptr<Encoder<float>> outputGateEncoder =
187             MakeEncoder<float>(outputGateInfo, outputGateData.data());
188     std::unique_ptr<Encoder<float>> hiddenStateEncoder =
189             MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());
190 
191     // Int16 used to accumulate output to prevent overflowing (after Projection MatMul)
192     std::unique_ptr<Decoder<float>> outputInt16Decoder =
193             MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
194     std::unique_ptr<Encoder<float>> outputInt16Encoder =
195             MakeEncoder<float>(outputInt16Info, outputInt16Data.data());
196 
197     // Create decoders for optional params if they are enabled
198     if (!cifgEnabled)
199     {
200         inputToInputWeightsDecoder = MakeDecoder<float>(
201                 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
202         recurrentToInputWeightsDecoder = MakeDecoder<float>(
203                 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
204     }
205 
206     if (peepholeEnabled)
207     {
208         if (!cifgEnabled)
209         {
210             cellToInputWeightsDecoder = MakeDecoder<float>(
211                     m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
212         }
213         cellToForgetWeightsDecoder = MakeDecoder<float>(
214                 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
215         cellToOutputWeightsDecoder = MakeDecoder<float>(
216                 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
217     }
218 
219     if (projectionEnabled)
220     {
221         projectionWeightsDecoder = MakeDecoder<float>(
222                 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
223         if (m_ProjectionBiasTensor)
224         {
225             projectionBiasDecoder = MakeDecoder<float>(
226                     m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
227         }
228     }
229 
230     if (layerNormEnabled)
231     {
232         if (!cifgEnabled)
233         {
234             inputLayerNormWeightsDecoder = MakeDecoder<float>(
235                     m_InputLayerNormWeightsTensor->GetTensorInfo(), m_InputLayerNormWeightsTensor->GetTensor<void>());
236 
237             // Bias only used if layer norm enabled
238             armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
239                     m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
240             inputGateBiasDecoder = MakeDecoder<float>(
241                     inputGateBiasTensorInfo, m_InputGateBiasTensor->GetTensor<void>());
242         }
243 
244         forgetLayerNormWeightsDecoder = MakeDecoder<float>(
245                 m_ForgetLayerNormWeightsTensor->GetTensorInfo(), m_ForgetLayerNormWeightsTensor->GetTensor<void>());
246         cellLayerNormWeightsDecoder = MakeDecoder<float>(
247                 m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetTensor<void>());
248         outputLayerNormWeightsDecoder = MakeDecoder<float>(
249                 m_OutputLayerNormWeightsTensor->GetTensorInfo(), m_OutputLayerNormWeightsTensor->GetTensor<void>());
250 
251         // Bias only used if layer norm enabled
252         armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
253                 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
254         forgetGateBiasDecoder = MakeDecoder<float>(
255                 forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetTensor<void>());
256 
257         armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
258                 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
259         cellGateBiasDecoder = MakeDecoder<float>(
260                 cellGateBiasTensorInfo, m_CellBiasTensor->GetTensor<void>());
261 
262         armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
263                 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
264         outputGateBiasDecoder = MakeDecoder<float>(
265                 outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetTensor<void>());
266     }
267 
268     // Initialize internal state tensors with zeroes.
269     if (!cifgEnabled)
270     {
271         ZeroVector(*inputGateEncoder, stateTensorSize);
272     }
273     ZeroVector(*forgetGateEncoder, stateTensorSize);
274     ZeroVector(*cellGateEncoder, stateTensorSize);
275     ZeroVector(*outputGateEncoder, stateTensorSize);
276     ZeroVector(*hiddenStateEncoder, stateTensorSize);
277 
278     // Input weights * Input
279     if (!cifgEnabled)
280     {
281         MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
282                                             numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
283     }
284 
285     MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
286                                         numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);
287 
288     MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
289                                         numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);
290 
291     MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
292                                         numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);
293 
294     // Recurrent weights * OutputStateIn
295     if (!cifgEnabled)
296     {
297         MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
298                                             numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
299     }
300 
301     MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
302                                         numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);
303 
304     MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
305                                         numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);
306 
307     MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
308                                         numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);
309 
310     // Input gate.
311     if (!cifgEnabled)
312     {
313         if (peepholeEnabled)
314         {
315             VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
316                                                     numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
317         }
318 
319         if (layerNormEnabled)
320         {
321             inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
322                                                m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
323                                                1024);
324             inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
325 
326             MeanStddevNormalization(*inputGateDecoder,
327                                     *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
328 
329             inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
330 
331             VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
332                                           numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
333 
334             inputGateInfo.SetQuantizationScale(1.f / 4096);
335             inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
336 
337             VectorBatchVectorAdd(*inputGateBiasDecoder,
338                                  numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
339 
340             inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
341         }
342 
343         inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
344         inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
345 
346         // Input gate sigmoid
347         Activation(*inputGateDecoder, *inputGateEncoder,
348                    TensorInfo({numUnits, numBatches}, internalType),
349                    ActivationFunction::Sigmoid, 0, 0);
350 
351         inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
352     }
353 
354     // Forget gate
355     if (peepholeEnabled)
356     {
357         VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
358                                                 *cellStateInDecoder, numBatches, *forgetGateEncoder);
359     }
360 
361     if (layerNormEnabled)
362     {
363         // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
364         forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
365                                             m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
366                                             1024);
367         forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
368 
369 
370 
371         MeanStddevNormalization(*forgetGateDecoder,
372                                 *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
373 
374 
375         forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
376 
377         VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
378                                       numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
379 
380 
381         // Dequantize layer norm output to (1 / 4096)
382         forgetGateInfo.SetQuantizationScale(1.f / 4096);
383         forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
384 
385         VectorBatchVectorAdd(*forgetGateBiasDecoder,
386                              numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
387 
388 
389         forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
390     }
391 
392     forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
393     forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
394 
395     // Forget gate sigmoid
396     Activation(*forgetGateDecoder, *forgetGateEncoder,
397                TensorInfo({numUnits, numBatches}, internalType),
398                ActivationFunction::Sigmoid, 0, 0);
399 
400     forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
401 
402     // Cell (Modulation) gate
403     if (layerNormEnabled)
404     {
405         cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
406                                           m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
407                                           1024);
408         cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
409 
410         MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
411 
412         cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
413 
414         VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
415                                       numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
416 
417         cellGateInfo.SetQuantizationScale(1.f / 4096);
418         cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
419 
420         VectorBatchVectorAdd(*cellGateBiasDecoder,
421                              numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
422 
423         cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
424     }
425 
426     cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
427     cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
428 
429     // Cell (Modulation) gate tanH
430     Activation(*cellGateDecoder, *cellGateEncoder,
431                TensorInfo({numUnits, numBatches}, internalType),
432                ActivationFunction::TanH, 1.0f, 1.0f);
433 
434     cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
435 
436     VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);
437 
438     if (cifgEnabled)
439     {
440         Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
441         VectorVectorCwiseProductAccumulate(
442                 *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
443     }
444     else
445     {
446         VectorVectorCwiseProductAccumulate(
447                 *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
448     }
449 
450     // Final cell state out calculated here
451     if (m_Data.m_Parameters.m_CellClip > 0.0)
452     {
453         ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
454     }
455 
456     // Output gate.
457     if (peepholeEnabled)
458     {
459         VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
460                                                 numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
461     }
462 
463     if (layerNormEnabled)
464     {
465         outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
466                                             m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
467                                             1024);
468         outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
469 
470         MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
471 
472         outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
473 
474         VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
475                                       numBatches, *outputGateEncoder);
476 
477         outputGateInfo.SetQuantizationScale(1.f / 4096);
478         outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
479 
480         VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);
481 
482         outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
483     }
484 
485     outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
486     outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
487 
488     // Output gate sigmoid
489     Activation(*outputGateDecoder, *outputGateEncoder,
490                TensorInfo({numUnits, numBatches}, internalType),
491                ActivationFunction::Sigmoid, 0, 0);
492 
493     outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
494 
495     // Hidden state tanH
496     Activation(*cellStateOutDecoder, *cellGateEncoder,
497                TensorInfo({numUnits, numBatches}, internalType),
498                ActivationFunction::TanH, 1.0f, 1.0f);
499 
500     // Final hidden state output
501     VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);
502 
503     // Projection
504     if (m_Data.m_Parameters.m_ProjectionEnabled)
505     {
506         if (m_ProjectionBiasTensor)
507         {
508             VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
509         }
510 
511         MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
512                                             numBatches, *outputInt16Encoder);
513 
514         CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);
515 
516         if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
517         {
518             ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
519         }
520     }
521     else
522     {
523         // Output has same quantization scale as hidden state if projection is disabled
524         CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
525     }
526 
527     // output == outputStateOut
528     CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
529 }
530 
531 } //namespace armnn
532