• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "LstmTestImpl.hpp"
7 
8 #include <armnnUtils/QuantizeHelper.hpp>
9 
10 #include <armnn/utility/NumericCast.hpp>
11 
12 #include <armnn/backends/TensorHandle.hpp>
13 
14 #include <armnnTestUtils/TensorCopyUtils.hpp>
15 #include <armnnTestUtils/WorkloadTestUtils.hpp>
16 
17 #include <reference/workloads/Decoders.hpp>
18 #include <reference/workloads/Encoders.hpp>
19 #include <reference/workloads/LstmUtils.hpp>
20 
21 #include <armnnTestUtils/TensorHelpers.hpp>
22 
23 #include <doctest/doctest.h>
24 namespace
25 {
26 
27 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LstmUtilsVectorBatchVectorAddTestImpl(std::vector<float> & vec,std::vector<float> & batchVec,uint32_t vSize,uint32_t nBatch,std::vector<float> & expectedOutput,armnn::TensorShape & expectedShape)28 void LstmUtilsVectorBatchVectorAddTestImpl(
29         std::vector<float>& vec,
30         std::vector<float>& batchVec,
31         uint32_t vSize,
32         uint32_t nBatch,
33         std::vector<float>& expectedOutput,
34         armnn::TensorShape& expectedShape)
35 {
36     float qScale = 1.0f;
37     int32_t qOffset = 0;
38     armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType,  qScale, qOffset );
39 
40     // Make encoder and decoder
41     std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
42     std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
43     std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
44 
45     VectorBatchVectorAdd(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder);
46 
47     // check shape and compare values
48     auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
49     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
50 
51     // check if iterator is back at start position
52     batchVecEncoder->Set(1.0f);
53     CHECK(batchVec[0] == 1.0f);
54 }
55 
56 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LstmUtilsZeroVectorTestImpl(std::vector<float> & input,uint32_t vSize,std::vector<float> & expectedOutput,armnn::TensorShape & expectedShape)57 void LstmUtilsZeroVectorTestImpl(
58         std::vector<float>& input,
59         uint32_t vSize,
60         std::vector<float>& expectedOutput,
61         armnn::TensorShape& expectedShape)
62 {
63     float qScale = 1.0f;
64     int32_t qOffset = 0;
65 
66     armnn::TensorInfo tensorInfo({vSize}, ArmnnType,  qScale, qOffset );
67 
68     // Make encoder for input
69     std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
70 
71     // call ZeroVector
72     ZeroVector(*outputEncoder, vSize);
73 
74     // check shape and compare values
75     auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
76     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
77 
78     // check if iterator is back at start position
79     outputEncoder->Set(1.0f);
80     CHECK(input[0] == 1.0f);
81 
82 }
83 
84 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LstmUtilsMeanStddevNormalizationTestImpl(std::vector<float> & input,uint32_t vSize,uint32_t nBatch,std::vector<float> & expectedOutput,armnn::TensorShape & expectedShape)85 void LstmUtilsMeanStddevNormalizationTestImpl(
86         std::vector<float>& input,
87         uint32_t vSize,
88         uint32_t nBatch,
89         std::vector<float>& expectedOutput,
90         armnn::TensorShape& expectedShape)
91 {
92     float qScale = 1.0f;
93     int32_t qOffset = 0;
94     armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType,  qScale, qOffset );
95 
96     // Make encoder and decoder for input
97     std::unique_ptr<armnn::Decoder<float>> inputDecoder = armnn::MakeDecoder<float>(tensorInfo, input.data());
98     std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
99 
100     MeanStddevNormalization(*inputDecoder, *outputEncoder, vSize, nBatch, 1e-8f);
101 
102     // check shape and compare values
103     auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
104     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
105 
106     // check if iterator is back at start position
107     outputEncoder->Set(1.0f);
108     CHECK(input[0] == 1.0f);
109 }
110 
111 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LstmUtilsVectorBatchVectorCwiseProductTestImpl(std::vector<float> & vec,std::vector<float> & batchVec,uint32_t vSize,uint32_t nBatch,std::vector<float> & expectedOutput,armnn::TensorShape & expectedShape)112 void LstmUtilsVectorBatchVectorCwiseProductTestImpl(
113         std::vector<float>& vec,
114         std::vector<float>& batchVec,
115         uint32_t vSize,
116         uint32_t nBatch,
117         std::vector<float>& expectedOutput,
118         armnn::TensorShape& expectedShape)
119 {
120     float qScale = 1.0f;
121     int32_t qOffset = 0;
122     armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType,  qScale, qOffset );
123 
124     // Make encoder and decoder
125     std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
126     std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
127     std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
128 
129     VectorBatchVectorCwiseProduct(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder);
130 
131     // check shape and compare values
132     auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
133     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
134 
135     // check if iterator is back at start position
136     batchVecEncoder->Set(1.0f);
137     CHECK(batchVec[0] == 1.0f);
138 }
139 
140 // Lstm Layer tests:
141 // *********************************** //
142 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
143 LayerTestResult<T, 2>
LstmNoCifgNoPeepholeNoProjectionTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,const std::vector<T> & input,const std::vector<T> & outputExpected,const armnn::TensorShape & inputShape,const armnn::TensorShape & outputExpectedShape,float qScale=1.0f,int32_t qOffset=0,armnn::DataType constantDataType=armnn::DataType::Float32)144 LstmNoCifgNoPeepholeNoProjectionTestImpl(
145         armnn::IWorkloadFactory& workloadFactory,
146         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
147         const armnn::ITensorHandleFactory& tensorHandleFactory,
148         const std::vector<T>& input,
149         const std::vector<T>& outputExpected,
150         const armnn::TensorShape& inputShape,
151         const armnn::TensorShape& outputExpectedShape,
152         float qScale = 1.0f,
153         int32_t qOffset = 0,
154         armnn::DataType constantDataType = armnn::DataType::Float32)
155 {
156     IgnoreUnused(memoryManager);
157     unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]);
158     unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
159     unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
160     // cellSize and outputSize have the same size when there is no projection.
161     unsigned numUnits = outputSize;
162 
163     armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType,  qScale, qOffset );
164     armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
165     armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
166 
167     armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
168     armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
169     armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
170     armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
171 
172     std::vector<T> inputVector;
173     inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
174 
175     std::vector<T> cellStateInVector(batchSize * numUnits, T());
176     std::vector<T> outputStateInVector(batchSize * outputSize, T());
177     std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
178     std::vector<T> outputStateOutVector(batchSize * outputSize, T());
179     std::vector<T> cellStateOutVector(batchSize * numUnits, T());
180 
181     std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
182 
183     std::vector<T> outputVector;
184     outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
185 
186     std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
187     std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
188             tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
189     std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
190             tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
191 
192     std::unique_ptr<armnn::ITensorHandle> scratchHandle =
193             tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
194     std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
195             tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
196     std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
197             tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
198     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
199 
200     armnn::LstmQueueDescriptor data;
201     armnn::WorkloadInfo info;
202 
203     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
204     AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
205     AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
206 
207     AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
208     AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
209     AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
210     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
211 
212     armnn::TensorInfo tensorInfo4({numUnits}, constantDataType , qScale, qOffset);
213     armnn::TensorInfo tensorInfo8({numUnits, 2}, constantDataType, qScale, qOffset);
214     armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset);
215 
216     std::vector<float> inputToInputWeights = {-0.45018822f, -0.02338299f, -0.0870589f,
217                                               -0.34550029f, 0.04266912f, -0.15680569f,
218                                               -0.34856534f, 0.43890524f};
219 
220     std::vector<float> inputToForgetWeights = { 0.09701663f, 0.20334584f, -0.50592935f,
221                                                 -0.31343272f, -0.40032279f, 0.44781327f,
222                                                 0.01387155f, -0.35593212f};
223 
224     std::vector<float> inputToCellWeights = { -0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
225                                               -0.20583314f, 0.44344562f, 0.22077113f,
226                                               -0.29909778f};
227 
228     std::vector<float> inputToOutputWeights = { -0.25065863f, -0.28290087f, 0.04613829f,
229                                                 0.40525138f, 0.44272184f, 0.03897077f,
230                                                 -0.1556896f, 0.19487578f};
231 
232     std::vector<float> recurrentToInputWeights = {-0.0063535f, -0.2042388f, 0.31454784f,
233                                                   -0.35746509f, 0.28902304f, 0.08183324f,
234                                                   -0.16555229f, 0.02286911f, -0.13566875f,
235                                                   0.03034258f, 0.48091322f, -0.12528998f,
236                                                   0.24077177f, -0.51332325f, -0.33502164f,
237                                                   0.10629296f};
238 
239     std::vector<float> recurrentToForgetWeights = { -0.48684245f, -0.06655136f, 0.42224967f,
240                                                     0.2112639f, 0.27654213f, 0.20864892f,
241                                                     -0.07646349f, 0.45877004f, 0.00141793f,
242                                                     -0.14609534f, 0.36447752f, 0.09196436f,
243                                                     0.28053468f, 0.01560611f, -0.20127171f,
244                                                     -0.01140004f};
245 
246     std::vector<float> recurrentToCellWeights = { -0.3407414f, 0.24443203f, -0.2078532f,
247                                                   0.26320225f, 0.05695659f, -0.00123841f,
248                                                   -0.4744786f, -0.35869038f, -0.06418842f,
249                                                   -0.13502428f, -0.501764f, 0.22830659f,
250                                                   -0.46367589f, 0.26016325f, -0.03894562f,
251                                                   -0.16368064f};
252 
253     std::vector<float> recurrentToOutputWeights = { 0.43385774f, -0.17194885f, 0.2718237f,
254                                                     0.09215671f, 0.24107647f, -0.39835793f,
255                                                     0.18212086f, 0.01301402f, 0.48572797f,
256                                                     -0.50656658f, 0.20047462f, -0.20607421f,
257                                                     -0.51818722f, -0.15390486f, 0.0468148f,
258                                                     0.39922136f};
259 
260     std::vector<float> cellToInputWeights = {0., 0., 0., 0.};
261 
262     std::vector<float> inputGateBias = {0., 0., 0., 0.};
263 
264     std::vector<float> forgetGateBias = {1., 1., 1., 1.};
265 
266     std::vector<float> cellBias = {0., 0., 0., 0.};
267 
268     std::vector<float> outputGateBias = {0., 0., 0., 0.};
269 
270     armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo8);
271     armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo8);
272     armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo8);
273     armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo8);
274     armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo16);
275     armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo16);
276     armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo16);
277     armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo16);
278     armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4);
279     armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4);
280     armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4);
281     armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4);
282     armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4);
283 
284     AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
285     AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
286     AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
287     AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
288     AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
289     AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
290     AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
291     AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
292     AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
293     AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
294     AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
295     AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
296     AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
297 
298     data.m_InputToInputWeights = &inputToInputWeightsTensor;
299     data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
300     data.m_InputToCellWeights = &inputToCellWeightsTensor;
301     data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
302     data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
303     data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
304     data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
305     data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
306     data.m_InputGateBias = &inputGateBiasTensor;
307     data.m_ForgetGateBias = &forgetGateBiasTensor;
308     data.m_CellBias = &cellBiasTensor;
309     data.m_OutputGateBias = &outputGateBiasTensor;
310 
311     // Flags to set test configuration
312     data.m_Parameters.m_ActivationFunc = 4;
313     data.m_Parameters.m_CifgEnabled = false;
314     data.m_Parameters.m_PeepholeEnabled = false;
315     data.m_Parameters.m_ProjectionEnabled = false;
316 
317     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
318     inputHandle->Allocate();
319     outputStateInHandle->Allocate();
320     cellStateInHandle->Allocate();
321 
322     scratchHandle->Allocate();
323     outputStateOutHandle->Allocate();
324     cellStateOutHandle->Allocate();
325     outputHandle->Allocate();
326 
327     CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
328     CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
329     CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
330 
331     workload->Execute();
332 
333     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
334 
335     return LayerTestResult<T, 2>(actualOutput,
336                                  outputVector,
337                                  outputHandle->GetShape(),
338                                  outputTensorInfo.GetShape());
339 }
340 
341 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
342 LayerTestResult<T, 2>
LstmLayerNoCifgWithPeepholeWithProjectionTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,const std::vector<T> & input,const std::vector<T> & outputExpected,float qScale=1.0f,int32_t qOffset=0,armnn::DataType constantDataType=armnn::DataType::Float32)343 LstmLayerNoCifgWithPeepholeWithProjectionTestImpl(armnn::IWorkloadFactory& workloadFactory,
344                                                   const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
345                                                   const armnn::ITensorHandleFactory& tensorHandleFactory,
346                                                   const std::vector<T>& input,
347                                                   const std::vector<T>& outputExpected,
348                                                   float qScale = 1.0f,
349                                                   int32_t qOffset = 0,
350                                                   armnn::DataType constantDataType = armnn::DataType::Float32)
351 {
352     IgnoreUnused(memoryManager);
353     unsigned int batchSize = 2;
354     unsigned int outputSize = 16;
355     unsigned int inputSize = 5;
356     unsigned numUnits = 20;
357 
358     armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
359     armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
360     armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
361 
362     // Scratch buffer size without CIFG [batchSize, numUnits * 4]
363     armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
364     armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
365     armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
366     armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
367 
368     std::vector<T> inputVector;
369     inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
370 
371     std::vector<T> cellStateInVector(batchSize * numUnits, T());
372     std::vector<T> outputStateInVector(batchSize * outputSize, T());
373     std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
374     std::vector<T> outputStateOutVector(batchSize * outputSize, T());
375     std::vector<T> cellStateOutVector(batchSize * numUnits, T());
376 
377     std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
378 
379     std::vector<T> outputVector;
380     outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
381 
382     std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
383     std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
384             tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
385     std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
386             tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
387 
388     std::unique_ptr<armnn::ITensorHandle> scratchHandle =
389             tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
390     std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
391             tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
392     std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
393             tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
394     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
395 
396     armnn::LstmQueueDescriptor data;
397     armnn::WorkloadInfo info;
398 
399     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
400     AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
401     AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
402 
403     AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
404     AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
405     AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
406     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
407 
408     armnn::TensorInfo tensorInfo16({outputSize}, constantDataType, qScale, qOffset);
409     armnn::TensorInfo tensorInfo20({numUnits}, constantDataType, qScale, qOffset);
410     armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
411     armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, constantDataType, qScale, qOffset);
412     armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, constantDataType, qScale, qOffset);
413 
414     std::vector<float> inputToInputWeights = {0.021393683f,0.06124551f,  0.046905167f,-0.014657677f,-0.03149463f,
415                                               0.09171803f, 0.14647801f,0.10797193f,   -0.0057968358f,0.0019193048f,
416                                               -0.2726754f, 0.10154029f, -0.018539885f, 0.080349885f, -0.10262385f,
417                                               -0.022599787f,-0.09121155f, -0.008675967f, -0.045206103f,-0.0821282f,
418                                               -0.008045952f,0.015478081f, 0.055217247f,  0.038719587f, 0.044153627f,
419                                               -0.06453243f,0.05031825f, -0.046935108f, -0.008164439f, 0.014574226f,
420                                               -0.1671009f,   -0.15519552f, -0.16819797f,-0.13971269f,-0.11953059f,
421                                               0.25005487f, -0.22790983f, 0.009855087f,  -0.028140958f, -0.11200698f,
422                                               0.11295408f, -0.0035217577f, 0.054485075f,  0.05184695f, 0.064711206f,
423                                               0.10989193f,   0.11674786f,  0.03490607f, 0.07727357f, 0.11390585f,
424                                               -0.1863375f,  -0.1034451f, -0.13945189f, -0.049401227f, -0.18767063f,
425                                               0.042483903f, 0.14233552f, 0.13832581f, 0.18350165f,    0.14545603f,
426                                               -0.028545704f,0.024939531f,0.050929718f,0.0076203286f,-0.0029723682f,
427                                               -0.042484224f, -0.11827596f, -0.09171104f,  -0.10808628f,-0.16327988f,
428                                               -0.2273378f,   -0.0993647f, -0.017155107f,0.0023917493f,0.049272764f,
429                                               0.0038534778f, 0.054764505f,   0.089753784f, 0.06947234f, 0.08014476f,
430                                               -0.04544234f, -0.0497073f,-0.07135631f,  -0.048929106f,-0.004042012f,
431                                               -0.009284026f, 0.018042054f, 0.0036860977f,-0.07427302f, -0.11434604f,
432                                               -0.018995456f, 0.031487543f, 0.012834908f,0.019977754f,0.044256654f,
433                                               -0.39292613f,  -0.18519334f, -0.11651281f,-0.06809892f, 0.011373677f };
434 
435     std::vector<float> inputToForgetWeights = {-0.0018401089f, -0.004852237f,0.03698424f, 0.014181704f,0.028273236f,
436                                                -0.016726194f, -0.05249759f,-0.10204261f, 0.00861066f,-0.040979505f,
437                                                -0.009899187f,0.01923892f,-0.028177269f, -0.08535103f,-0.14585495f,
438                                                0.10662567f,-0.01909731f,-0.017883534f,-0.0047269356f,-0.045103323f,
439                                                0.0030784295f,0.076784775f,0.07463696f, 0.094531395f,0.0814421f,
440                                                -0.12257899f, -0.033945758f,-0.031303465f, 0.045630626f,0.06843887f,
441                                                -0.13492945f, -0.012480007f,-0.0811829f, -0.07224499f,-0.09628791f,
442                                                0.045100946f,0.0012300825f, 0.013964662f, 0.099372394f,0.02543059f,
443                                                0.06958324f,    0.034257296f, 0.0482646f, 0.06267997f,0.052625068f,
444                                                0.12784666f,    0.07077897f,  0.025725935f, 0.04165009f,0.07241905f,
445                                                0.018668644f, -0.037377294f,-0.06277783f,-0.08833636f,-0.040120605f,
446                                                -0.011405586f,-0.007808335f,-0.010301386f,-0.005102167f,0.027717464f,
447                                                0.05483423f, 0.11449111f, 0.11289652f,0.10939839f, 0.13396506f,
448                                                -0.08402166f,-0.01901462f,  -0.044678304f,-0.07720565f,0.014350063f,
449                                                -0.11757958f, -0.0652038f, -0.08185733f,-0.076754324f,-0.092614375f,
450                                                0.10405491f, 0.052960336f, 0.035755895f,0.035839386f,-0.012540553f,
451                                                0.036881298f,   0.02913376f,  0.03420159f,0.05448447f,-0.054523353f,
452                                                0.02582715f, 0.02327355f, -0.011857179f,-0.0011980024f,-0.034641717f,
453                                                -0.026125094f,-0.17582615f,-0.15923657f,-0.27486774f,-0.0006143371f,
454                                                0.0001771948f,  -8.470171e-05f, 0.02651807f,0.045790765f,0.06956496f };
455 
456     std::vector<float> inputToCellWeights = { -0.04580283f,   -0.09549462f,   -0.032418985f,  -0.06454633f,
457                                               -0.043528453f,  0.043018587f,   -0.049152344f,  -0.12418144f,
458                                               -0.078985475f,  -0.07596889f,   0.019484362f,   -0.11434962f,
459                                               -0.0074034138f, -0.06314844f,   -0.092981495f,  0.0062155537f,
460                                               -0.025034338f,  -0.0028890965f, 0.048929527f,   0.06235075f,
461                                               0.10665918f,    -0.032036792f,  -0.08505916f,   -0.10843358f,
462                                               -0.13002433f,   -0.036816437f,  -0.02130134f,   -0.016518239f,
463                                               0.0047691227f,  -0.0025825808f, 0.066017866f,   0.029991534f,
464                                               -0.10652836f,   -0.1037554f,    -0.13056071f,   -0.03266643f,
465                                               -0.033702414f,  -0.006473424f,  -0.04611692f,   0.014419339f,
466                                               -0.025174323f,  0.0396852f,     0.081777506f,   0.06157468f,
467                                               0.10210095f,    -0.009658194f,  0.046511717f,   0.03603906f,
468                                               0.0069369148f,  0.015960095f,   -0.06507666f,   0.09551598f,
469                                               0.053568836f,   0.06408714f,    0.12835667f,    -0.008714329f,
470                                               -0.20211966f,   -0.12093674f,   0.029450472f,   0.2849013f,
471                                               -0.029227901f,  0.1164364f,     -0.08560263f,   0.09941786f,
472                                               -0.036999565f,  -0.028842626f,  -0.0033637602f, -0.017012902f,
473                                               -0.09720865f,   -0.11193351f,   -0.029155117f,  -0.017936034f,
474                                               -0.009768936f,  -0.04223324f,   -0.036159635f,  0.06505112f,
475                                               -0.021742892f,  -0.023377212f,  -0.07221364f,   -0.06430552f,
476                                               0.05453865f,    0.091149814f,   0.06387331f,    0.007518393f,
477                                               0.055960953f,   0.069779344f,   0.046411168f,   0.10509911f,
478                                               0.07463894f,    0.0075130584f,  0.012850982f,   0.04555431f,
479                                               0.056955688f,   0.06555285f,    0.050801456f,   -0.009862683f,
480                                               0.00826772f,    -0.026555609f,  -0.0073611983f, -0.0014897042f };
481 
482     std::vector<float> inputToOutputWeights ={-0.0998932f,   -0.07201956f, -0.052803773f,-0.15629593f,-0.15001918f,
483                                               -0.07650751f,0.02359855f, -0.075155355f, -0.08037709f,  -0.15093534f,
484                                               0.029517552f, -0.04751393f, 0.010350531f,-0.02664851f, -0.016839722f,
485                                               -0.023121163f, 0.0077019283f, 0.012851257f, -0.05040649f,-0.0129761f,
486                                               -0.021737747f,-0.038305793f,-0.06870586f, -0.01481247f,-0.001285394f,
487                                               0.10124236f,  0.083122835f, 0.053313006f,-0.062235646f,-0.075637154f,
488                                               -0.027833903f, 0.029774971f,  0.1130802f, 0.09218906f, 0.09506135f,
489                                               -0.086665764f,-0.037162706f,-0.038880914f,-0.035832845f,-0.014481564f,
490                                               -0.09825003f,-0.12048569f,-0.097665586f,-0.05287633f, -0.0964047f,
491                                               -0.11366429f,  0.035777505f,  0.13568819f, 0.052451383f,0.050649304f,
492                                               0.05798951f, -0.021852335f,-0.099848844f,0.014740475f,-0.078897946f,
493                                               0.04974699f, 0.014160473f,  0.06973932f,    0.04964942f, 0.033364646f,
494                                               0.08190124f,   0.025535367f, 0.050893165f, 0.048514254f,0.06945813f,
495                                               -0.078907564f,-0.06707616f,  -0.11844508f, -0.09986688f,-0.07509403f,
496                                               0.06263226f,   0.14925587f,   0.20188436f, 0.12098451f,0.14639415f,
497                                               0.0015017595f, -0.014267382f, -0.03417257f,0.012711468f,0.0028300495f,
498                                               -0.024758482f, -0.05098548f,-0.0821182f, 0.014225672f,  0.021544158f,
499                                               0.08949725f,  0.07505268f, -0.0020780868f, 0.04908258f,0.06476295f,
500                                               -0.022907063f,0.027562456f,0.040185735f, 0.019567577f,-0.015598739f,
501                                               -0.049097303f, -0.017121866f, -0.083368234f,-0.02332002f,-0.0840956f };
502 
503     std::vector<float> inputGateBias = {0.02234832f,  0.14757581f,   0.18176508f,  0.10380666f,  0.053110216f,
504                                         -0.06928846f, -0.13942584f,  -0.11816189f, 0.19483899f,  0.03652339f,
505                                         -0.10250295f, 0.036714908f,  -0.18426876f, 0.036065217f, 0.21810818f,
506                                         0.02383196f,  -0.043370757f, 0.08690144f,  -0.04444982f, 0.00030581196f };
507 
508     std::vector<float> forgetGateBias ={0.035185695f, -0.042891346f, -0.03032477f, 0.23027696f,
509                                         0.11098921f,  0.15378423f,   0.09263801f,  0.09790885f,
510                                         0.09508917f,  0.061199076f,  0.07665568f,  -0.015443159f,
511                                         -0.03499149f, 0.046190713f,  0.08895977f,  0.10899629f,
512                                         0.40694186f,  0.06030037f,   0.012413437f, -0.06108739f };
513 
514     std::vector<float> cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f,   0.033463873f,
515                                     -0.1483596f,   -0.10639995f,  -0.091433935f, 0.058573797f,
516                                     -0.06809782f,  -0.07889636f,  -0.043246906f, -0.09829136f,
517                                     -0.4279842f,   0.034901652f,  0.18797937f,   0.0075234566f,
518                                     0.016178843f,  0.1749513f,    0.13975595f,   0.92058027f };
519 
520     std::vector<float> outputGateBias ={0.046159424f,  -0.0012809046f, 0.03563469f, 0.12648113f, 0.027195795f,
521                                         0.35373217f,   -0.018957434f,  0.008907322f, -0.0762701f, 0.12018895f,
522                                         0.04216877f,   0.0022856654f,  0.040952638f,  0.3147856f,  0.08225149f,
523                                         -0.057416286f, -0.14995944f,   -0.008040261f, 0.13208859f, 0.029760877f};
524 
525     std::vector<float> recurrentToInputWeights = { -0.001374326f,   -0.078856036f,   0.10672688f,    0.029162422f,
526                                                    -0.11585556f,    0.02557986f,     -0.13446963f,   -0.035785314f,
527                                                    -0.01244275f,    0.025961924f,    -0.02337298f,   -0.044228926f,
528                                                    -0.055839065f,   -0.046598054f,   -0.010546039f,  -0.06900766f,
529                                                    0.027239809f,    0.022582639f,    -0.013296484f,  -0.05459212f,
530                                                    0.08981f,        -0.045407712f,   0.08682226f,    -0.06867011f,
531                                                    -0.14390695f,    -0.02916037f,    0.000996957f,   0.091420636f,
532                                                    0.14283475f,     -0.07390571f,    -0.06402044f,   0.062524505f,
533                                                    -0.093129106f,   0.04860203f,     -0.08364217f,   -0.08119002f,
534                                                    0.009352075f,    0.22920375f,     0.0016303885f,  0.11583097f,
535                                                    -0.13732095f,    0.012405723f,    -0.07551853f,   0.06343048f,
536                                                    0.12162708f,     -0.031923793f,   -0.014335606f,  0.01790974f,
537                                                    -0.10650317f,    -0.0724401f,     0.08554849f,    -0.05727212f,
538                                                    0.06556731f,     -0.042729504f,   -0.043227166f,  0.011683251f,
539                                                    -0.013082158f,   -0.029302018f,   -0.010899579f,  -0.062036745f,
540                                                    -0.022509435f,   -0.00964907f,    -0.01567329f,   0.04260106f,
541                                                    -0.07787477f,    -0.11576462f,    0.017356863f,   0.048673786f,
542                                                    -0.017577527f,   -0.05527947f,    -0.082487635f,  -0.040137455f,
543                                                    -0.10820036f,    -0.04666372f,    0.022746278f,   -0.07851417f,
544                                                    0.01068115f,     0.032956902f,    0.022433773f,   0.0026891115f,
545                                                    0.08944216f,     -0.0685835f,     0.010513544f,   0.07228705f,
546                                                    0.02032331f,     -0.059686817f,   -0.0005566496f, -0.086984694f,
547                                                    0.040414046f,    -0.1380399f,     0.094208956f,   -0.05722982f,
548                                                    0.012092817f,    -0.04989123f,    -0.086576f,     -0.003399834f,
549                                                    -0.04696032f,    -0.045747425f,   0.10091314f,    0.048676282f,
550                                                    -0.029037097f,   0.031399418f,    -0.0040285117f, 0.047237843f,
551                                                    0.09504992f,     0.041799378f,    -0.049185462f,  -0.031518843f,
552                                                    -0.10516937f,    0.026374253f,    0.10058866f,    -0.0033195973f,
553                                                    -0.041975245f,   0.0073591834f,   0.0033782164f,  -0.004325073f,
554                                                    -0.10167381f,    0.042500053f,    -0.01447153f,   0.06464186f,
555                                                    -0.017142897f,   0.03312627f,     0.009205989f,   0.024138335f,
556                                                    -0.011337001f,   0.035530265f,    -0.010912711f,  0.0706555f,
557                                                    -0.005894094f,   0.051841937f,    -0.1401738f,    -0.02351249f,
558                                                    0.0365468f,      0.07590991f,     0.08838724f,    0.021681072f,
559                                                    -0.10086113f,    0.019608743f,    -0.06195883f,   0.077335775f,
560                                                    0.023646897f,    -0.095322326f,   0.02233014f,    0.09756986f,
561                                                    -0.048691444f,   -0.009579111f,   0.07595467f,    0.11480546f,
562                                                    -0.09801813f,    0.019894179f,    0.08502348f,    0.004032281f,
563                                                    0.037211012f,    0.068537936f,    -0.048005626f,  -0.091520436f,
564                                                    -0.028379958f,   -0.01556313f,    0.06554592f,    -0.045599163f,
565                                                    -0.01672207f,    -0.020169014f,   -0.011877351f,  -0.20212261f,
566                                                    0.010889619f,    0.0047078193f,   0.038385306f,   0.08540671f,
567                                                    -0.017140968f,   -0.0035865551f,  0.016678626f,   0.005633034f,
568                                                    0.015963363f,    0.00871737f,     0.060130805f,   0.028611384f,
569                                                    0.10109069f,     -0.015060172f,   -0.07894427f,   0.06401885f,
570                                                    0.011584063f,    -0.024466386f,   0.0047652307f,  -0.09041358f,
571                                                    0.030737216f,    -0.0046374933f,  0.14215417f,    -0.11823516f,
572                                                    0.019899689f,    0.006106124f,    -0.027092824f,  0.0786356f,
573                                                    0.05052217f,     -0.058925f,      -0.011402121f,  -0.024987547f,
574                                                    -0.0013661642f,  -0.06832946f,    -0.015667673f,  -0.1083353f,
575                                                    -0.00096863037f, -0.06988685f,    -0.053350925f,  -0.027275559f,
576                                                    -0.033664223f,   -0.07978348f,    -0.025200296f,  -0.017207067f,
577                                                    -0.058403496f,   -0.055697463f,   0.005798788f,   0.12965427f,
578                                                    -0.062582195f,   0.0013350133f,   -0.10482091f,   0.0379771f,
579                                                    0.072521195f,    -0.0029455067f,  -0.13797039f,   -0.03628521f,
580                                                    0.013806405f,    -0.017858358f,   -0.01008298f,   -0.07700066f,
581                                                    -0.017081132f,   0.019358726f,    0.0027079724f,  0.004635139f,
582                                                    0.062634714f,    -0.02338735f,    -0.039547626f,  -0.02050681f,
583                                                    0.03385117f,     -0.083611414f,   0.002862572f,   -0.09421313f,
584                                                    0.058618143f,    -0.08598433f,    0.00972939f,    0.023867095f,
585                                                    -0.053934585f,   -0.023203006f,   0.07452513f,    -0.048767887f,
586                                                    -0.07314807f,    -0.056307215f,   -0.10433547f,   -0.06440842f,
587                                                    0.04328182f,     0.04389765f,     -0.020006588f,  -0.09076438f,
588                                                    -0.11652589f,    -0.021705797f,   0.03345259f,    -0.010329105f,
589                                                    -0.025767034f,   0.013057034f,    -0.07316461f,   -0.10145612f,
590                                                    0.06358255f,     0.18531723f,     0.07759293f,    0.12006465f,
591                                                    0.1305557f,      0.058638252f,    -0.03393652f,   0.09622831f,
592                                                    -0.16253184f,    -2.4580743e-06f, 0.079869635f,   -0.070196845f,
593                                                    -0.005644518f,   0.06857898f,     -0.12598175f,   -0.035084512f,
594                                                    0.03156317f,     -0.12794146f,    -0.031963028f,  0.04692781f,
595                                                    0.030070418f,    0.0071660685f,   -0.095516115f,  -0.004643372f,
596                                                    0.040170413f,    -0.062104587f,   -0.0037324072f, 0.0554317f,
597                                                    0.08184801f,     -0.019164372f,   0.06791302f,    0.034257166f,
598                                                    -0.10307039f,    0.021943003f,    0.046745934f,   0.0790918f,
599                                                    -0.0265588f,     -0.007824208f,   0.042546265f,   -0.00977924f,
600                                                    -0.0002440307f,  -0.017384544f,   -0.017990116f,  0.12252321f,
601                                                    -0.014512694f,   -0.08251313f,    0.08861942f,    0.13589665f,
602                                                    0.026351685f,    0.012641483f,    0.07466548f,    0.044301085f,
603                                                    -0.045414884f,   -0.051112458f,   0.03444247f,    -0.08502782f,
604                                                    -0.04106223f,    -0.028126027f,   0.028473156f,   0.10467447f };
605 
606     std::vector<float> recurrentToForgetWeights = {-0.057784554f,  -0.026057621f,  -0.068447545f,   -0.022581743f,
607                                                    0.14811787f,    0.10826372f,    0.09471067f,     0.03987225f,
608                                                    -0.0039523416f, 0.00030638507f, 0.053185795f,    0.10572994f,
609                                                    0.08414449f,    -0.022036452f,  -0.00066928595f, -0.09203576f,
610                                                    0.032950465f,   -0.10985798f,   -0.023809856f,   0.0021431844f,
611                                                    -0.02196096f,   -0.00326074f,   0.00058621005f,  -0.074678116f,
612                                                    -0.06193199f,   0.055729095f,   0.03736828f,     0.020123724f,
613                                                    0.061878487f,   -0.04729229f,   0.034919553f,    -0.07585433f,
614                                                    -0.04421272f,   -0.044019096f,  0.085488975f,    0.04058006f,
615                                                    -0.06890133f,   -0.030951202f,  -0.024628663f,   -0.07672815f,
616                                                    0.034293607f,   0.08556707f,    -0.05293577f,    -0.033561368f,
617                                                    -0.04899627f,   0.0241671f,     0.015736353f,    -0.095442444f,
618                                                    -0.029564252f,  0.016493602f,   -0.035026584f,   0.022337519f,
619                                                    -0.026871363f,  0.004780428f,   0.0077918363f,   -0.03601621f,
620                                                    0.016435321f,   -0.03263031f,   -0.09543275f,    -0.047392778f,
621                                                    0.013454138f,   0.028934088f,   0.01685226f,     -0.086110644f,
622                                                    -0.046250615f,  -0.01847454f,   0.047608484f,    0.07339695f,
623                                                    0.034546845f,   -0.04881143f,   0.009128804f,    -0.08802852f,
624                                                    0.03761666f,    0.008096139f,   -0.014454086f,   0.014361001f,
625                                                    -0.023502491f,  -0.0011840804f, -0.07607001f,    0.001856849f,
626                                                    -0.06509276f,   -0.006021153f,  -0.08570962f,    -0.1451793f,
627                                                    0.060212336f,   0.055259194f,   0.06974018f,     0.049454916f,
628                                                    -0.027794661f,  -0.08077226f,   -0.016179763f,   0.1169753f,
629                                                    0.17213494f,    -0.0056326236f, -0.053934924f,   -0.0124349f,
630                                                    -0.11520337f,   0.05409887f,    0.088759385f,    0.0019655675f,
631                                                    0.0042065294f,  0.03881498f,    0.019844765f,    0.041858196f,
632                                                    -0.05695512f,   0.047233116f,   0.038937137f,    -0.06542224f,
633                                                    0.014429736f,   -0.09719407f,   0.13908425f,     -0.05379757f,
634                                                    0.012321099f,   0.082840554f,   -0.029899208f,   0.044217527f,
635                                                    0.059855383f,   0.07711018f,    -0.045319796f,   0.0948846f,
636                                                    -0.011724666f,  -0.0033288454f, -0.033542685f,   -0.04764985f,
637                                                    -0.13873616f,   0.040668588f,   0.034832682f,    -0.015319203f,
638                                                    -0.018715994f,  0.046002675f,   0.0599172f,      -0.043107376f,
639                                                    0.0294216f,     -0.002314414f,  -0.022424703f,   0.0030315618f,
640                                                    0.0014641669f,  0.0029166266f,  -0.11878115f,    0.013738511f,
641                                                    0.12375372f,    -0.0006038222f, 0.029104086f,    0.087442465f,
642                                                    0.052958444f,   0.07558703f,    0.04817258f,     0.044462286f,
643                                                    -0.015213451f,  -0.08783778f,   -0.0561384f,     -0.003008196f,
644                                                    0.047060397f,   -0.002058388f,  0.03429439f,     -0.018839769f,
645                                                    0.024734668f,   0.024614193f,   -0.042046934f,   0.09597743f,
646                                                    -0.0043254104f, 0.04320769f,    0.0064070094f,   -0.0019131786f,
647                                                    -0.02558259f,   -0.022822596f,  -0.023273505f,   -0.02464396f,
648                                                    -0.10991725f,   -0.006240552f,  0.0074488563f,   0.024044557f,
649                                                    0.04383914f,    -0.046476185f,  0.028658995f,    0.060410924f,
650                                                    0.050786525f,   0.009452605f,   -0.0073054377f,  -0.024810238f,
651                                                    0.0052906186f,  0.0066939713f,  -0.0020913032f,  0.014515517f,
652                                                    0.015898481f,   0.021362653f,   -0.030262267f,   0.016587038f,
653                                                    -0.011442813f,  0.041154444f,   -0.007631438f,   -0.03423484f,
654                                                    -0.010977775f,  0.036152758f,   0.0066366293f,   0.11915515f,
655                                                    0.02318443f,    -0.041350313f,  0.021485701f,    -0.10906167f,
656                                                    -0.028218046f,  -0.00954771f,   0.020531068f,    -0.11995105f,
657                                                    -0.03672871f,   0.024019798f,   0.014255957f,    -0.05221243f,
658                                                    -0.00661567f,   -0.04630967f,   0.033188973f,    0.10107534f,
659                                                    -0.014027541f,  0.030796422f,   -0.10270911f,    -0.035999842f,
660                                                    0.15443139f,    0.07684145f,    0.036571592f,    -0.035900835f,
661                                                    -0.0034699554f, 0.06209149f,    0.015920248f,    -0.031122351f,
662                                                    -0.03858649f,   0.01849943f,    0.13872518f,     0.01503974f,
663                                                    0.069941424f,   -0.06948533f,   -0.0088794185f,  0.061282158f,
664                                                    -0.047401894f,  0.03100163f,    -0.041533746f,   -0.10430945f,
665                                                    0.044574402f,   -0.01425562f,   -0.024290353f,   0.034563623f,
666                                                    0.05866852f,    0.023947537f,   -0.09445152f,    0.035450947f,
667                                                    0.02247216f,    -0.0042998926f, 0.061146557f,    -0.10250651f,
668                                                    0.020881841f,   -0.06747029f,   0.10062043f,     -0.0023941975f,
669                                                    0.03532124f,    -0.016341697f,  0.09685456f,     -0.016764693f,
670                                                    0.051808182f,   0.05875331f,    -0.04536488f,    0.001626336f,
671                                                    -0.028892258f,  -0.01048663f,   -0.009793449f,   -0.017093895f,
672                                                    0.010987891f,   0.02357273f,    -0.00010856845f, 0.0099760275f,
673                                                    -0.001845119f,  -0.03551521f,   0.0018358806f,   0.05763657f,
674                                                    -0.01769146f,   0.040995963f,   0.02235177f,     -0.060430344f,
675                                                    0.11475477f,    -0.023854522f,  0.10071741f,     0.0686208f,
676                                                    -0.014250481f,  0.034261297f,   0.047418304f,    0.08562733f,
677                                                    -0.030519066f,  0.0060542435f,  0.014653856f,    -0.038836084f,
678                                                    0.04096551f,    0.032249358f,   -0.08355519f,    -0.026823482f,
679                                                    0.056386515f,   -0.010401743f,  -0.028396193f,   0.08507674f,
680                                                    0.014410365f,   0.020995233f,   0.17040324f,     0.11511526f,
681                                                    0.02459721f,    0.0066619175f,  0.025853224f,    -0.023133837f,
682                                                    -0.081302024f,  0.017264642f,   -0.009585969f,   0.09491168f,
683                                                    -0.051313367f,  0.054532815f,   -0.014298593f,   0.10657464f,
684                                                    0.007076659f,   0.10964551f,    0.0409152f,      0.008275321f,
685                                                    -0.07283536f,   0.07937492f,    0.04192024f,     -0.1075027f };
686 
687     std::vector<float> recurrentToCellWeights = { -0.037322544f,   0.018592842f,   0.0056175636f,  -0.06253426f,
688                                                    0.055647098f,    -0.05713207f,   -0.05626563f,   0.005559383f,
689                                                    0.03375411f,     -0.025757805f,  -0.088049285f,  0.06017052f,
690                                                    -0.06570978f,    0.007384076f,   0.035123326f,   -0.07920549f,
691                                                    0.053676967f,    0.044480428f,   -0.07663568f,   0.0071805613f,
692                                                    0.08089997f,     0.05143358f,    0.038261272f,   0.03339287f,
693                                                    -0.027673481f,   0.044746667f,   0.028349208f,   0.020090483f,
694                                                    -0.019443132f,   -0.030755889f,  -0.0040000007f, 0.04465846f,
695                                                    -0.021585021f,   0.0031670958f,  0.0053199246f,  -0.056117613f,
696                                                    -0.10893326f,    0.076739706f,   -0.08509834f,   -0.027997585f,
697                                                    0.037871376f,    0.01449768f,    -0.09002357f,   -0.06111149f,
698                                                    -0.046195522f,   0.0422062f,     -0.005683705f,  -0.1253618f,
699                                                    -0.012925729f,   -0.04890792f,   0.06985068f,    0.037654128f,
700                                                    0.03398274f,     -0.004781977f,  0.007032333f,   -0.031787455f,
701                                                    0.010868644f,    -0.031489216f,  0.09525667f,    0.013939797f,
702                                                    0.0058680447f,   0.0167067f,     0.02668468f,    -0.04797466f,
703                                                    -0.048885044f,   -0.12722108f,   0.035304096f,   0.06554885f,
704                                                    0.00972396f,     -0.039238118f,  -0.05159735f,   -0.11329045f,
705                                                    0.1613692f,      -0.03750952f,   0.06529313f,    -0.071974665f,
706                                                    -0.11769596f,    0.015524369f,   -0.0013754242f, -0.12446318f,
707                                                    0.02786344f,     -0.014179351f,  0.005264273f,   0.14376344f,
708                                                    0.015983658f,    0.03406988f,    -0.06939408f,   0.040699873f,
709                                                    0.02111075f,     0.09669095f,    0.041345075f,   -0.08316494f,
710                                                    -0.07684199f,    -0.045768797f,  0.032298047f,   -0.041805092f,
711                                                    0.0119405f,      0.0061010392f,  0.12652606f,    0.0064572375f,
712                                                    -0.024950314f,   0.11574242f,    0.04508852f,    -0.04335324f,
713                                                    0.06760663f,     -0.027437469f,  0.07216407f,    0.06977076f,
714                                                    -0.05438599f,    0.034033038f,   -0.028602652f,  0.05346137f,
715                                                    0.043184172f,    -0.037189785f,  0.10420091f,    0.00882477f,
716                                                    -0.054019816f,   -0.074273005f,  -0.030617684f,  -0.0028467078f,
717                                                    0.024302477f,    -0.0038869337f, 0.005332455f,   0.0013399826f,
718                                                    0.04361412f,     -0.007001822f,  0.09631092f,    -0.06702025f,
719                                                    -0.042049985f,   -0.035070654f,  -0.04103342f,   -0.10273396f,
720                                                    0.0544271f,      0.037184782f,   -0.13150354f,   -0.0058036847f,
721                                                    -0.008264958f,   0.042035464f,   0.05891794f,    0.029673764f,
722                                                    0.0063542654f,   0.044788733f,   0.054816857f,   0.062257513f,
723                                                    -0.00093483756f, 0.048938446f,   -0.004952862f,  -0.007730018f,
724                                                    -0.04043371f,    -0.017094059f,  0.07229206f,    -0.023670016f,
725                                                    -0.052195564f,   -0.025616996f,  -0.01520939f,   0.045104615f,
726                                                    -0.007376126f,   0.003533447f,   0.006570588f,   0.056037236f,
727                                                    0.12436656f,     0.051817212f,   0.028532185f,   -0.08686856f,
728                                                    0.11868599f,     0.07663395f,    -0.07323171f,   0.03463402f,
729                                                    -0.050708205f,   -0.04458982f,   -0.11590894f,   0.021273347f,
730                                                    0.1251325f,      -0.15313013f,   -0.12224372f,   0.17228661f,
731                                                    0.023029093f,    0.086124025f,   0.006445803f,   -0.03496501f,
732                                                    0.028332196f,    0.04449512f,    -0.042436164f,  -0.026587414f,
733                                                    -0.006041347f,   -0.09292539f,   -0.05678812f,   0.03897832f,
734                                                    0.09465633f,     0.008115513f,   -0.02171956f,   0.08304309f,
735                                                    0.071401566f,    0.019622514f,   0.032163795f,   -0.004167056f,
736                                                    0.02295182f,     0.030739572f,   0.056506045f,   0.004612461f,
737                                                    0.06524936f,     0.059999723f,   0.046395954f,   -0.0045512207f,
738                                                    -0.1335546f,     -0.030136576f,  0.11584653f,    -0.014678886f,
739                                                    0.0020118146f,   -0.09688814f,   -0.0790206f,    0.039770417f,
740                                                    -0.0329582f,     0.07922767f,    0.029322514f,   0.026405897f,
741                                                    0.04207835f,     -0.07073373f,   0.063781224f,   0.0859677f,
742                                                    -0.10925287f,    -0.07011058f,   0.048005477f,   0.03438226f,
743                                                    -0.09606514f,    -0.006669445f,  -0.043381985f,  0.04240257f,
744                                                    -0.06955775f,    -0.06769346f,   0.043903265f,   -0.026784198f,
745                                                    -0.017840602f,   0.024307009f,   -0.040079936f,  -0.019946516f,
746                                                    0.045318738f,    -0.12233574f,   0.026170589f,   0.0074471775f,
747                                                    0.15978073f,     0.10185836f,    0.10298046f,    -0.015476589f,
748                                                    -0.039390966f,   -0.072174534f,  0.0739445f,     -0.1211869f,
749                                                    -0.0347889f,     -0.07943156f,   0.014809798f,   -0.12412325f,
750                                                    -0.0030663363f,  0.039695457f,   0.0647603f,     -0.08291318f,
751                                                    -0.018529687f,   -0.004423833f,  0.0037507233f,  0.084633216f,
752                                                    -0.01514876f,    -0.056505352f,  -0.012800942f,  -0.06994386f,
753                                                    0.012962922f,    -0.031234352f,  0.07029052f,    0.016418684f,
754                                                    0.03618972f,     0.055686004f,   -0.08663945f,   -0.017404709f,
755                                                    -0.054761406f,   0.029065743f,   0.052404847f,   0.020238016f,
756                                                    0.0048197987f,   -0.0214882f,    0.07078733f,    0.013016777f,
757                                                    0.06262858f,     0.009184685f,   0.020785125f,   -0.043904778f,
758                                                    -0.0270329f,     -0.03299152f,   -0.060088247f,  -0.015162964f,
759                                                    -0.001828936f,   0.12642565f,    -0.056757294f,  0.013586685f,
760                                                    0.09232601f,     -0.035886683f,  0.06000002f,    0.05229691f,
761                                                    -0.052580316f,   -0.082029596f,  -0.010794592f,  0.012947712f,
762                                                    -0.036429964f,   -0.085508935f,  -0.13127148f,   -0.017744139f,
763                                                    0.031502828f,    0.036232427f,   -0.031581745f,  0.023051167f,
764                                                    -0.05325106f,    -0.03421577f,   0.028793324f,   -0.034633752f,
765                                                    -0.009881397f,   -0.043551125f,  -0.018609839f,  0.0019097115f,
766                                                    -0.008799762f,   0.056595087f,   0.0022273948f,  0.055752404f };
767 
768     std::vector<float> recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f,-0.045984812f, -0.01255415f,
769                                                    -0.0026479573f,-0.08196161f,-0.054914974f,-0.0046604523f,
770                                                    -0.029587349f, -0.044576716f,  -0.07480124f,  -0.082868785f,
771                                                    0.023254942f,    0.027502948f, -0.0039728214f, -0.08683098f,
772                                                    -0.08116779f,  -0.014675607f,   -0.037924774f, -0.023314456f,
773                                                    -0.007401714f, -0.09255757f,  0.029460307f,    -0.08829125f,
774                                                     -0.005139627f,  -0.08989442f,  -0.0555066f,   0.13596267f,
775                                                    -0.025062224f, -0.048351806f,  -0.03850004f,  0.07266485f,
776                                                    -0.022414139f,   0.05940088f, 0.075114764f,   0.09597592f,
777                                                    -0.010211725f, -0.0049794707f,  -0.011523867f, -0.025980417f,
778                                                    0.072999895f,  0.11091378f,   -0.081685916f,   0.014416728f,
779                                                     0.043229222f,   0.034178585f,  -0.07530371f,  0.035837382f,
780                                                    -0.085607f, -0.007721233f,  -0.03287832f,  -0.043848954f,
781                                                    -0.06404588f,    -0.06632928f, -0.073643476f,  0.008214239f,
782                                                    -0.045984086f, 0.039764922f,    0.03474462f, 0.060612556f,
783                                                    -0.080590084f, 0.049127717f,  0.04151091f,     -0.030063879f,
784                                                     0.008801774f,   -0.023021035f, -0.019558564f, 0.05158114f,
785                                                    -0.010947698f, -0.011825728f,  0.0075720972f, 0.0699727f,
786                                                    -0.0039981045f,  0.069350146f, 0.08799282f,    0.016156472f,
787                                                    0.035502106f,  0.11695009f,     0.006217345f, 0.13392477f,
788                                                    -0.037875112f, 0.025745004f,  0.08940699f,     -0.00924166f,
789                                                     0.0046702605f,  -0.036598757f, -0.08811812f,  0.10522024f,
790                                                    -0.032441203f, 0.008176899f,   -0.04454919f,  0.07058152f,
791                                                    0.0067963637f,   0.039206743f, 0.03259838f,    0.03725492f,
792                                                    -0.09515802f,  0.013326398f,    -0.052055415f, -0.025676316f,
793                                                    0.03198509f,   -0.015951829f, -0.058556724f,   0.036879618f,
794                                                     0.043357447f,   0.028362012f,  -0.05908629f,  0.0059240665f,
795                                                    -0.04995891f, -0.019187413f,0.0276265f, -0.01628143f, 0.0025863599f,
796                                                    0.08800015f, 0.035250366f,   -0.022165963f, -0.07328642f,
797                                                    -0.009415526f,   -0.07455109f, 0.11690406f,    0.0363299f,
798                                                    0.07411125f,   0.042103454f,    -0.009660886f, 0.019076364f,
799                                                    0.018299393f, -0.046004917f, 0.08891175f,0.0431396f, -0.026327137f,
800                                                    -0.051502608f, 0.08979574f,   -0.051670972f,   0.04940282f,
801                                                     -0.07491107f,   -0.021240504f, 0.022596184f,  -0.034280192f,
802                                                    0.060163025f, -0.058211457f,  -0.051837247f, -0.01349775f,
803                                                    -0.04639988f,    -0.035936575f, -0.011681591f,  0.064818054f,
804                                                    0.0073146066f, -0.021745546f,   -0.043124277f, -0.06471268f,
805                                                    -0.07053354f,  -0.029321948f, -0.05330136f,    0.016933719f,
806                                                     -0.053782392f,  0.13747959f,   -0.1361751f,   -0.11569455f,
807                                                    0.0033329215f, 0.05693899f,    -0.053219706f, 0.063698f,
808                                                    0.07977434f,     -0.07924483f, 0.06936997f,    0.0034815092f,
809                                                    -0.007305279f, -0.037325785f,   -0.07251102f, -0.033633437f,
810                                                    -0.08677009f,  0.091591336f,  -0.14165086f,    0.021752775f,
811                                                     0.019683983f,   0.0011612234f, -0.058154266f, 0.049996935f,
812                                                    0.0288841f, -0.0024567875f, -0.14345716f, 0.010955264f,-0.10234828f,
813                                                    0.1183656f, -0.0010731248f, -0.023590032f,-0.072285876f,-0.0724771f,
814                                                    -0.026382286f, -0.0014920527f, 0.042667855f,  0.0018776858f,
815                                                    0.02986552f,     0.009814309f, 0.0733756f,     0.12289186f,
816                                                    0.018043943f,  -0.0458958f,     0.049412545f, 0.033632483f,
817                                                    0.05495232f,   0.036686596f,  -0.013781798f,   -0.010036754f,
818                                                     0.02576849f,    -0.08307328f,  0.010112348f,  0.042521734f,
819                                                    -0.05869831f, -0.071689695f, 0.03876447f, -0.13275425f, -0.0352966f,
820                                                    -0.023077697f, 0.10285965f,    0.084736146f,  0.15568255f,
821                                                    -0.00040734606f, 0.027835453f, -0.10292561f,   -0.032401145f,
822                                                    0.10053256f,   -0.026142767f,   -0.08271222f, -0.0030240538f,
823                                                    -0.016368777f, 0.1070414f,    0.042672627f,    0.013456989f,
824                                                     -0.0437609f,    -0.022309763f, 0.11576483f,   0.04108048f,
825                                                    0.061026827f, -0.0190714f,  -0.0869359f, 0.037901703f,  0.0610107f,
826                                                    0.07202949f, 0.01675338f,    0.086139716f,  -0.08795751f,
827                                                    -0.014898893f,   -0.023771819f, -0.01965048f,   0.007955471f,
828                                                    -0.043740474f, 0.03346837f,     -0.10549954f, 0.090567775f,
829                                                    0.042013682f,  -0.03176985f,  0.12569028f,     -0.02421228f,
830                                                     -0.029526481f,  0.023851605f,  0.031539805f,  0.05292009f,
831                                                    -0.02344001f, -0.07811758f,   -0.08834428f,  0.10094801f,
832                                                    0.16594367f,     -0.06861939f, -0.021256343f,  -0.041093912f,
833                                                    -0.06669611f,  0.035498552f,    0.021757556f, -0.09302526f,
834                                                    -0.015403468f, -0.06614931f,  -0.051798206f,   -0.013874718f,
835                                                     0.03630673f,    0.010412845f,  -0.08077351f,  0.046185967f,
836                                                    0.0035662893f, 0.03541868f,    -0.094149634f, -0.034814864f,
837                                                    0.003128424f,    -0.020674974f, -0.03944324f,   -0.008110165f,
838                                                    -0.11113267f,  0.08484226f,     0.043586485f, 0.040582247f,
839                                                    0.0968012f,    -0.065249965f, -0.028036479f,   0.0050708856f,
840                                                     0.0017462453f,  0.0326779f,    0.041296225f,  0.09164146f,
841                                                    -0.047743853f, -0.015952192f,  -0.034451712f, 0.084197424f,
842                                                    -0.05347844f,    -0.11768019f, 0.085926116f,   -0.08251791f,
843                                                    -0.045081906f, 0.0948852f,      0.068401024f, 0.024856757f,
844                                                    0.06978981f,   -0.057309967f, -0.012775832f,   -0.0032452994f,
845                                                    0.01977615f, -0.041040014f, -0.024264973f,0.063464895f, 0.05431621f};
846 
847     std::vector<float> cellToInputWeights = {0.040369894f, 0.030746894f,  0.24704495f,  0.018586371f, -0.037586458f,
848                                             -0.15312155f, -0.11812848f,  -0.11465643f, 0.20259799f,   0.11418174f,
849                                             -0.10116027f, -0.011334949f, 0.12411352f, -0.076769054f,-0.052169047f,
850                                             0.21198851f,  -0.38871562f,  -0.09061183f, -0.09683246f,  -0.21929175f};
851 
852 
853     std::vector<float> cellToForgetWeights = {-0.01998659f,-0.15568835f,-0.24248174f,   -0.012770197f, 0.041331276f,
854                                               -0.072311886f, -0.052123554f,-0.0066330447f,-0.043891653f,0.036225766f,
855                                               -0.047248036f, 0.021479502f,0.033189066f, 0.11952997f,   -0.020432774f,
856                                               0.64658105f,   -0.06650122f,  -0.03467612f,  0.095340036f, 0.23647355f};
857 
858     std::vector<float> cellToOutputWeights = { 0.08286371f,  -0.08261836f, -0.51210177f, 0.002913762f, 0.17764764f,
859                                                -0.5495371f,  -0.08460716f, -0.24552552f, 0.030037103f, 0.04123544f,
860                                                -0.11940523f, 0.007358328f, 0.1890978f,   0.4833202f,   -0.34441817f,
861                                                0.36312827f,  -0.26375428f, 0.1457655f,   -0.19724406f, 0.15548733f};
862 
863     std::vector<float> projectionWeights={-0.009802181f,  0.09401916f,    0.0717386f,     -0.13895074f,  0.09641832f,
864                                           0.060420845f,   0.08539281f,    0.054285463f,   0.061395317f,  0.034448683f,
865                                           -0.042991187f,  0.019801661f,   -0.16840284f,   -0.015726732f, -0.23041931f,
866                                           -0.024478018f,  -0.10959692f,   -0.013875541f,  0.18600968f,   -0.061274476f,
867                                           0.0138165f,     -0.08160894f,   -0.07661644f,   0.032372914f,  0.16169067f,
868                                           0.22465782f,    -0.03993472f,   -0.004017731f,  0.08633481f,   -0.28869787f,
869                                           0.08682067f,    0.17240396f,    0.014975425f,   0.056431185f,  0.031037588f,
870                                           0.16702051f,    0.0077946745f,  0.15140012f,    0.29405436f,   0.120285f,
871                                           -0.188994f,     -0.027265169f,  0.043389652f,   -0.022061434f, 0.014777949f,
872                                           -0.20203483f,   0.094781205f,   0.19100232f,    0.13987629f,   -0.036132768f,
873                                           -0.06426278f,   -0.05108664f,   0.13221376f,    0.009441198f,  -0.16715929f,
874                                           0.15859416f,    -0.040437475f,  0.050779544f,   -0.022187516f, 0.012166504f,
875                                           0.027685808f,   -0.07675938f,   -0.0055694645f, -0.09444123f,  0.0046453946f,
876                                           0.050794356f,   0.10770313f,    -0.20790008f,   -0.07149004f,  -0.11425117f,
877                                           0.008225835f,   -0.035802525f,  0.14374903f,    0.15262283f,   0.048710253f,
878                                           0.1847461f,     -0.007487823f,  0.11000021f,    -0.09542012f,  0.22619456f,
879                                           -0.029149994f,  0.08527916f,    0.009043713f,   0.0042746216f, 0.016261552f,
880                                           0.022461696f,   0.12689082f,    -0.043589946f,  -0.12035478f,  -0.08361797f,
881                                           -0.050666027f,  -0.1248618f,    -0.1275799f,    -0.071875185f, 0.07377272f,
882                                           0.09944291f,    -0.18897448f,   -0.1593054f,    -0.06526116f,  -0.040107165f,
883                                           -0.004618631f,  -0.067624845f,  -0.007576253f,  0.10727444f,   0.041546922f,
884                                           -0.20424393f,   0.06907816f,    0.050412357f,   0.00724631f,   0.039827548f,
885                                           0.12449835f,    0.10747581f,    0.13708383f,    0.09134148f,   -0.12617786f,
886                                           -0.06428341f,   0.09956831f,    0.1208086f,     -0.14676677f,  -0.0727722f,
887                                           0.1126304f,     0.010139365f,   0.015571211f,   -0.038128063f, 0.022913318f,
888                                           -0.042050496f,  0.16842307f,    -0.060597885f,  0.10531834f,   -0.06411776f,
889                                           -0.07451711f,   -0.03410368f,   -0.13393489f,   0.06534304f,   0.003620307f,
890                                           0.04490757f,    0.05970546f,    0.05197996f,    0.02839995f,   0.10434969f,
891                                           -0.013699693f,  -0.028353551f,  -0.07260381f,   0.047201227f,  -0.024575593f,
892                                           -0.036445823f,  0.07155557f,    0.009672501f,   -0.02328883f,  0.009533515f,
893                                           -0.03606021f,   -0.07421458f,   -0.028082801f,  -0.2678904f,   -0.13221288f,
894                                           0.18419984f,    -0.13012612f,   -0.014588381f,  -0.035059117f, -0.04824723f,
895                                           0.07830115f,    -0.056184657f,  0.03277091f,    0.025466874f,  0.14494097f,
896                                           -0.12522776f,   -0.098633975f,  -0.10766018f,   -0.08317623f,  0.08594209f,
897                                           0.07749552f,    0.039474737f,   0.1776665f,     -0.07409566f,  -0.0477268f,
898                                           0.29323658f,    0.10801441f,    0.1154011f,     0.013952499f,  0.10739139f,
899                                           0.10708251f,    -0.051456142f,  0.0074137426f,  -0.10430189f,  0.10034707f,
900                                           0.045594677f,   0.0635285f,     -0.0715442f,    -0.089667566f, -0.10811871f,
901                                           0.00026344223f, 0.08298446f,    -0.009525053f,  0.006585689f,  -0.24567553f,
902                                           -0.09450807f,   0.09648481f,    0.026996298f,   -0.06419476f,  -0.04752702f,
903                                           -0.11063944f,   -0.23441927f,   -0.17608605f,   -0.052156363f, 0.067035615f,
904                                           0.19271925f,    -0.0032889997f, -0.043264326f,  0.09663576f,   -0.057112187f,
905                                           -0.10100678f,   0.0628376f,     0.04447668f,    0.017961001f,  -0.10094388f,
906                                           -0.10190601f,   0.18335468f,    0.10494553f,    -0.052095775f, -0.0026118709f,
907                                           0.10539724f,    -0.04383912f,   -0.042349473f,  0.08438151f,   -0.1947263f,
908                                           0.02251204f,    0.11216432f,    -0.10307853f,   0.17351969f,   -0.039091777f,
909                                           0.08066188f,    -0.00561982f,   0.12633002f,    0.11335965f,   -0.0088127935f,
910                                           -0.019777594f,  0.06864014f,    -0.059751723f,  0.016233567f,  -0.06894641f,
911                                           -0.28651384f,   -0.004228674f,  0.019708522f,   -0.16305895f,  -0.07468996f,
912                                           -0.0855457f,    0.099339016f,   -0.07580735f,   -0.13775392f,  0.08434318f,
913                                           0.08330512f,    -0.12131499f,   0.031935584f,   0.09180414f,   -0.08876437f,
914                                           -0.08049874f,   0.008753825f,   0.03498998f,    0.030215185f,  0.03907079f,
915                                           0.089751154f,   0.029194152f,   -0.03337423f,   -0.019092513f, 0.04331237f,
916                                           0.04299654f,    -0.036394123f,  -0.12915532f,   0.09793732f,   0.07512415f,
917                                           -0.11319543f,   -0.032502122f,  0.15661901f,    0.07671967f,   -0.005491124f,
918                                           -0.19379048f,   -0.218606f,     0.21448623f,    0.017840758f,  0.1416943f,
919                                           -0.07051762f,   0.19488361f,    0.02664691f,    -0.18104725f,  -0.09334311f,
920                                           0.15026465f,    -0.15493552f,   -0.057762887f,  -0.11604192f,  -0.262013f,
921                                           -0.01391798f,   0.012185008f,   0.11156489f,    -0.07483202f,  0.06693364f,
922                                           -0.26151478f,   0.046425626f,   0.036540434f,   -0.16435726f,  0.17338543f,
923                                           -0.21401681f,   -0.11385144f,   -0.08283257f,   -0.069031075f, 0.030635102f,
924                                           0.010969227f,   0.11109743f,    0.010919218f,   0.027526086f,  0.13519906f,
925                                           0.01891392f,    -0.046839405f,  -0.040167913f,  0.017953383f,  -0.09700955f,
926                                           0.0061885654f,  -0.07000971f,   0.026893595f,   -0.038844477f, 0.14543656f};
927 
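    // Projection is enabled in this configuration; the projection bias is supplied as an all-zero vector of length outputSize.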
928     std::vector<float> projectionBiasVector(outputSize, 0.f);
929 
930     armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo20x5);
931     armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo20x5);
932     armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo20x5);
933     armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo20x5);
934     armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo20x16);
935     armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo20x16);
936     armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo20x16);
937     armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo20x16);
938     armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo20);
939     armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo20);
940     armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo20);
941     armnn::ScopedTensorHandle cellBiasTensor(tensorInfo20);
942     armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo20);
943     armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo20);
944     armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo20);
945     armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo16x20);
946     armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo16);
947 
948     AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
949     AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
950     AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
951     AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
952     AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
953     AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
954     AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
955     AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
956     AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
957     AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
958     AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
959     AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
960     AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
961     AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
962     AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
963     AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
964     AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data());
965 
966     data.m_InputToInputWeights = &inputToInputWeightsTensor;
967     data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
968     data.m_InputToCellWeights = &inputToCellWeightsTensor;
969     data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
970     data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
971     data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
972     data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
973     data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
974     data.m_CellToInputWeights = &cellToInputWeightsTensor;
975     data.m_InputGateBias = &inputGateBiasTensor;
976     data.m_ForgetGateBias = &forgetGateBiasTensor;
977     data.m_CellBias = &cellBiasTensor;
978     data.m_OutputGateBias = &outputGateBiasTensor;
979     data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
980     data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
981     data.m_ProjectionWeights = &projectionWeightsTensor;
982     data.m_ProjectionBias = &projectionBiasTensor;
983 
984     // Flags to set test configuration
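    // ActivationFunc 4 selects tanh (TfLite/NNAPI activation-function numbering).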
985     data.m_Parameters.m_ActivationFunc = 4;
986     data.m_Parameters.m_CifgEnabled = false;
987     data.m_Parameters.m_PeepholeEnabled = true;
988     data.m_Parameters.m_ProjectionEnabled = true;
989 
990     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
991     inputHandle->Allocate();
992     outputStateInHandle->Allocate();
993     cellStateInHandle->Allocate();
994 
995     scratchHandle->Allocate();
996     outputStateOutHandle->Allocate();
997     cellStateOutHandle->Allocate();
998     outputHandle->Allocate();
999 
1000     CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1001     CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1002     CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
1003 
1004     workload->Execute();
1005 
1006     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
1007 
1008     return LayerTestResult<T, 2>(actualOutput,
1009                                  outputVector,
1010                                  outputHandle->GetShape(),
1011                                  outputTensorInfo.GetShape());
1012 }
1013 
1014 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
1015 LayerTestResult<T, 2> LstmLayerWithCifgWithPeepholeNoProjectionTestImpl(
1016         armnn::IWorkloadFactory& workloadFactory,
1017         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
1018         const armnn::ITensorHandleFactory& tensorHandleFactory,
1019         const std::vector<T>& input,
1020         const std::vector<T>& outputExpected,
1021         const armnn::TensorShape& inputShape,
1022         const armnn::TensorShape& outputExpectedShape,
1023         float qScale = 1.0f,
1024         int32_t qOffset = 0,
1025         armnn::DataType constantDataType = armnn::DataType::Float32)
1026 {
1027     IgnoreUnused(memoryManager);
1028     bool cifgEnabled = true;
1029     bool peepholeEnabled = true;
1030     bool projectionEnabled = false;
1031     // Derive the dimensions from the supplied shapes; the actual LSTM inputs and outputs are created further down
1032     unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]);
1033     unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
1034 
1035     unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
1036 
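    // With projection disabled, the number of cell units matches the output size.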
1037     const unsigned int cellSize = outputSize;
1038 
1039     // Decide the shape of all input tensors
1040     armnn::TensorInfo inputTensorInfo({batchSize, inputSize}, ArmnnType, qScale, qOffset);
1041     armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1042     armnn::TensorInfo cellStateInTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
1043 
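    // CIFG couples the input gate to the forget gate, so only three gate buffers are needed in the scratch buffer instead of four.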
1044     unsigned int scratchBufferSize = cifgEnabled ? cellSize * 3 : cellSize * 4;
1045     armnn::TensorInfo scratchBufferTensorInfo({batchSize, scratchBufferSize}, ArmnnType, qScale, qOffset);
1046     armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1047     armnn::TensorInfo cellStateOutTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
1048     armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1049 
1050     // List of inputs
1051     std::vector<float> inputData;
1052     inputData.assign(input.data(), input.data() + batchSize*inputSize);
1053 
1054     std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
1055 
1056     std::vector<float> cellStateInVector(batchSize * cellSize, 0.f);
1057 
1058     // Prepare all the weights in the descriptor for LSTM
1059     armnn::LstmQueueDescriptor data;
1060     armnn::TensorInfo tensorInfoInput({cellSize, inputSize}, constantDataType, qScale, qOffset);
1061     armnn::TensorInfo tensorInfoOutput({cellSize, outputSize}, constantDataType, qScale, qOffset);
1062     armnn::TensorInfo tensorInfoNumUnits({cellSize}, constantDataType, qScale, qOffset);
1063 
1064     std::vector<float> inputToCellWeights =
1065     {
1066         -0.49770179f, -0.27711356f, -0.09624726f, 0.05100781f,
1067         0.04717243f, 0.48944736f, -0.38535351f,
1068         -0.17212132f
1069     };
1070     std::vector<float> inputToForgetWeights =
1071     {
1072         -0.55291498f, -0.42866567f, 0.13056988f,
1073         -0.3633365f, -0.22755712f, 0.28253698f, 0.24407166f,
1074         0.33826375f
1075     };
1076     std::vector<float> inputToOutputWeights =
1077     {
1078         0.10725588f, -0.02335852f, -0.55932593f,
1079         -0.09426838f, -0.44257352f, 0.54939759f,
1080         0.01533556f, 0.42751634f
1081     };
1082     std::vector<float> cellBias =  {0.f, 0.f, 0.f, 0.f};
1083     std::vector<float> forgetGateBias =  {1.f, 1.f, 1.f, 1.f};
1084     std::vector<float> outputGateBias =  {0.f, 0.f, 0.f, 0.f};
1085 
1086     std::vector<float> recurrentToCellWeights =
1087     {
1088         0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f, 0.42957711f,
1089         0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f, 0.20675004f,
1090         0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f, 0.44901288f,
1091         0.21193194f
1092     };
1093     std::vector<float> recurrentToForgetWeights =
1094     {
1095         -0.13832897f, -0.0515101f, -0.2359007f, -0.16661474f, -0.14340827f,
1096         0.36986142f, 0.23414481f, 0.55899f, 0.10798943f, -0.41174671f, 0.17751795f,
1097         -0.34484994f, -0.35874045f, -0.11352962f, 0.27268326f, 0.54058349f
1098     };
1099 
1100     std::vector<float> recurrentToOutputWeights =
1101     {
1102         0.41613156f, 0.42610586f, -0.16495961f, -0.5663873f, 0.30579174f, -0.05115908f,
1103         -0.33941799f, 0.23364776f, 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f,
1104         0.50248802f, 0.26114327f, -0.43736315f, 0.33149987f
1105     };
1106 
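    // Peephole weights: with CIFG enabled only the forget and output gates take cell-state (peephole) connections, so no cell-to-input weights are set.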
1107     std::vector<float> cellToForgetWeights = {0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f};
1108     std::vector<float> cellToOutputWeights = {-0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f};
1109 
1110     armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfoInput);
1111     armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfoInput);
1112     armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfoInput);
1113 
1114     armnn::ScopedTensorHandle cellBiasTensor(tensorInfoNumUnits);
1115     armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfoNumUnits);
1116     armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfoNumUnits);
1117 
1118     armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfoOutput);
1119     armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfoOutput);
1120     armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfoOutput);
1121 
1122     armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfoNumUnits);
1123     armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfoNumUnits);
1124 
1125     AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1126     AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1127     AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
1128 
1129     AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1130     AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1131     AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
1132 
1133     AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1134     AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1135     AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
1136 
1137     AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
1138     AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
1139 
1140     data.m_InputToCellWeights = &inputToCellWeightsTensor;
1141     data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1142     data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1143 
1144     data.m_CellBias = &cellBiasTensor;
1145     data.m_ForgetGateBias = &forgetGateBiasTensor;
1146     data.m_OutputGateBias = &outputGateBiasTensor;
1147 
1148     data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1149     data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1150     data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1151 
1152     data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
1153     data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
1154 
1155     // Other parameters for the descriptor
1156     data.m_Parameters.m_CifgEnabled = cifgEnabled;
1157     data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
1158     data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
1159 
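    // ActivationFunc 4 is tanh; clipping thresholds of 0 leave the cell state and projection output unclipped.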
1160     data.m_Parameters.m_ActivationFunc = 4;
1161     data.m_Parameters.m_ClippingThresProj = 0.0;
1162     data.m_Parameters.m_ClippingThresCell = 0.0;
1163 
1164     // List of outputs
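    // The workload produces four outputs (scratch buffer, output state, cell state, output); only the final output is validated here.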
1165     std::vector<T> scratchBufferVector(batchSize * scratchBufferSize, T());
1166     LayerTestResult<T, 2> ret0(scratchBufferTensorInfo);
1167 
1168     // Output state for a certain time step
1169     std::vector<T> outputStateOutVector(batchSize * outputSize, T());
1170     LayerTestResult<T, 2> ret1(outputStateOutTensorInfo);
1171 
1172     // Cell state for a certain time step
1173     std::vector<T> cellStateOutVector(batchSize * cellSize, T());
1174     LayerTestResult<T, 2> ret2(cellStateOutTensorInfo);
1175 
1176     // Output for a certain time step
1177     std::vector<T> outputData;
1178     outputData.assign(outputExpected.data(), outputExpected.data() + batchSize*outputSize);
1179     LayerTestResult<T, 2> ret3(outputTensorInfo);
1180     ret3.m_ExpectedData = outputData;
1181 
1182     std::vector<T> actualScratchBufferOutput(scratchBufferTensorInfo.GetNumElements());
1183     std::vector<T> actualOutputStateOutput(outputStateOutTensorInfo.GetNumElements());
1184     std::vector<T> actualCellStateOutput(cellStateOutTensorInfo.GetNumElements());
1185     std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
1186 
1187     // Prepare the inputs and outputs for the workload
1188     std::unique_ptr<armnn::ITensorHandle> inputHandle =
1189             tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
1190     std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1191             tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
1192     std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1193             tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
1194 
1195     std::unique_ptr<armnn::ITensorHandle> scratchBufferHandle =
1196             tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
1197     std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1198             tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
1199     std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1200             tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
1201     std::unique_ptr<armnn::ITensorHandle> outputHandle =
1202             tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
1203 
1204     armnn::WorkloadInfo info;
1205     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1206     AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1207     AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1208 
1209     AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchBufferHandle.get());
1210     AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1211     AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1212     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1213 
1214     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
1215 
1216     inputHandle->Allocate();
1217     outputStateInHandle->Allocate();
1218     cellStateInHandle->Allocate();
1219 
1220     scratchBufferHandle->Allocate();
1221     outputStateOutHandle->Allocate();
1222     cellStateOutHandle->Allocate();
1223     outputHandle->Allocate();
1224 
1225     CopyDataToITensorHandle(inputHandle.get(), inputData.data());
1226     CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1227     CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
1228 
1229     CopyDataToITensorHandle(scratchBufferHandle.get(), scratchBufferVector.data());
1230     CopyDataToITensorHandle(outputStateOutHandle.get(), outputStateOutVector.data());
1231     CopyDataToITensorHandle(cellStateOutHandle.get(), cellStateOutVector.data());
1232 
1233     workload->Execute();
1234 
1235     CopyDataFromITensorHandle(actualScratchBufferOutput.data(), scratchBufferHandle.get());
1236     CopyDataFromITensorHandle(actualOutputStateOutput.data(), outputStateOutHandle.get());
1237     CopyDataFromITensorHandle(actualCellStateOutput.data(), cellStateOutHandle.get());
1238     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
1239 
1240     ret0.m_ActualData = actualScratchBufferOutput;
1241     ret1.m_ActualData = actualOutputStateOutput;
1242     ret2.m_ActualData = actualCellStateOutput;
1243     ret3.m_ActualData = actualOutput;
1244 
1245     return ret3;
1246 }
1247 
1248 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
1249 LayerTestResult<T, 2>
1250 LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl(armnn::IWorkloadFactory& workloadFactory,
1251                                                   const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
1252                                                   const armnn::ITensorHandleFactory& tensorHandleFactory,
1253                                                   const std::vector<T>& input,
1254                                                   const std::vector<T>& outputExpected,
1255                                                   float qScale = 1.0f,
1256                                                   int32_t qOffset = 0,
1257                                                   armnn::DataType constantDataType = armnn::DataType::Float32)
1258 {
1259     IgnoreUnused(memoryManager);
1260     unsigned int batchSize = 2;
1261     unsigned int outputSize = 3;
1262     unsigned int inputSize = 5;
1263     unsigned int numUnits = 4;
1264 
1265     armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
1266     armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
1267     armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
1268 
1269     // Scratch buffer size without CIFG [batchSize, numUnits * 4]
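    // One intermediate buffer per gate: input, forget, cell and output.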
1270     armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
1271     armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
1272     armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1273     armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1274 
1275     std::vector<float> inputVector;
1276     inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
1277 
1278     std::vector<float> cellStateInVector(batchSize * numUnits, 0.f);
1279     std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
1280     std::vector<float> scratchBufferVector(batchSize * numUnits * 4, 0.f);
1281     std::vector<float> outputStateOutVector(batchSize * outputSize, 0.f);
1282     std::vector<float> cellStateOutVector(batchSize * numUnits, 0.f);
1283 
1284     std::vector<float> actualOutput(outputTensorInfo.GetNumElements());
1285 
1286     std::vector<float> outputVector;
1287     outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
1288 
1289     std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
1290     std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1291             tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
1292     std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1293             tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
1294 
1295     std::unique_ptr<armnn::ITensorHandle> scratchHandle =
1296             tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
1297     std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1298             tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
1299     std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1300             tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
1301     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
1302 
1303     armnn::LstmQueueDescriptor data;
1304     armnn::WorkloadInfo info;
1305 
1306     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1307     AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1308     AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1309 
1310     AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
1311     AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1312     AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1313     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1314 
1315     armnn::TensorInfo tensorInfo3({outputSize}, constantDataType, qScale, qOffset);
1316     armnn::TensorInfo tensorInfo4({numUnits}, constantDataType, qScale, qOffset);
1317     armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
1318     armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, constantDataType, qScale, qOffset);
1319     armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, constantDataType, qScale, qOffset);
1320 
1321     std::vector<float> inputToInputWeights = {0.5f,  0.6f,  0.7f, -0.8f, -0.9f,
1322                                               0.1f,  0.2f,  0.3f, -0.4f,  0.5f,
1323                                              -0.8f,  0.7f, -0.6f,  0.5f, -0.4f,
1324                                              -0.5f, -0.4f, -0.3f, -0.2f, -0.1f};  //{numUnits, inputSize}
1325 
1326     std::vector<float> inputToForgetWeights = { -0.6f, -0.1f,  0.3f,  0.2f,  0.9f,
1327                                                 -0.5f, -0.2f, -0.4f,  0.3f, -0.8f,
1328                                                 -0.4f,  0.3f, -0.5f, -0.4f, -0.6f,
1329                                                  0.3f, -0.4f, -0.6f, -0.5f, -0.5f};  //{numUnits, inputSize}
1330 
1331     std::vector<float> inputToCellWeights = {-0.4f, -0.3f, -0.2f, -0.1f, -0.5f,
1332                                               0.5f, -0.2f, -0.3f, -0.2f, -0.6f,
1333                                               0.6f, -0.1f, -0.4f, -0.3f, -0.7f,
1334                                               0.7f, -0.9f, -0.5f,  0.8f,  0.6f};  //{numUnits, inputSize}
1335 
1336     std::vector<float> inputToOutputWeights = {-0.8f, -0.4f, -0.2f, -0.9f, -0.1f,
1337                                                -0.7f,  0.3f, -0.3f, -0.8f, -0.2f,
1338                                                 0.6f, -0.2f,  0.4f, -0.7f, -0.3f,
1339                                                -0.5f,  0.1f,  0.5f, -0.6f, -0.4f}; //{numUnits, inputSize}
1340 
1341     std::vector<float> inputGateBias = {0.03f, 0.15f, 0.22f, 0.38f};  //{numUnits}
1342 
1343     std::vector<float> forgetGateBias = {0.1f, -0.3f, -0.2f, 0.1f};    //{numUnits}
1344 
1345     std::vector<float> cellBias = {-0.05f, 0.72f, 0.25f, 0.08f}; //{numUnits}
1346 
1347     std::vector<float> outputGateBias = {0.05f, -0.01f, 0.2f, 0.1f};   //{numUnits}
1348 
1349     std::vector<float> recurrentToInputWeights ={-0.2f, -0.3f,  0.4f,
1350                                                   0.1f, -0.5f,  0.9f,
1351                                                  -0.2f, -0.3f, -0.7f,
1352                                                  0.05f, -0.2f, -0.6f};  //{numUnits, outputSize}
1353 
1354     std::vector<float> recurrentToCellWeights = {-0.3f,  0.2f,   0.1f,
1355                                                  -0.3f,  0.8f, -0.08f,
1356                                                  -0.2f,  0.3f,   0.8f,
1357                                                  -0.6f, -0.1f,   0.2f}; //{numUnits, outputSize}
1358 
1359     std::vector<float> recurrentToForgetWeights = { -0.5f, -0.3f, -0.5f,
1360                                                     -0.2f,  0.6f,  0.4f,
1361                                                      0.9f,  0.3f, -0.1f,
1362                                                      0.2f,  0.5f,  0.2f};  //{numUnits, outputSize}
1363 
1364     std::vector<float> recurrentToOutputWeights = { 0.3f, -0.1f,  0.1f,
1365                                                    -0.2f, -0.5f, -0.7f,
1366                                                    -0.2f, -0.6f, -0.1f,
1367                                                    -0.4f, -0.7f, -0.2f};  //{numUnits, outputSize}
1368 
1369     std::vector<float> cellToInputWeights = {0.05f, 0.1f, 0.25f, 0.15f};      //{numUnits}
1370 
1371     std::vector<float> cellToForgetWeights = {-0.02f, -0.15f, -0.25f, -0.03f}; //{numUnits}
1372 
1373     std::vector<float> cellToOutputWeights = {0.1f, -0.1f, -0.5f, 0.05f};      //{numUnits}
1374 
1375     std::vector<float> projectionWeights = {-0.1f, 0.2f, 0.01f, -0.2f,
1376                                              0.1f, 0.5f,  0.3f, 0.08f,
1377                                              0.07f, 0.2f, -0.4f,  0.2f}; //{outputSize, numUnits}
1378 
1379     std::vector<float> projectionBiasVector(outputSize, 0.f); //{outputSize}
1380 
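    // Layer normalisation weights, one {numUnits} vector per gate, used to scale the normalised gate inputs.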
1381     std::vector<float> inputLayerNormWeights = {0.1f, 0.2f, 0.3f, 0.5f}; //{numUnits}
1382 
1383     std::vector<float> forgetLayerNormWeights = {0.2f, 0.2f, 0.4f, 0.3f}; //{numUnits}
1384 
1385     std::vector<float> cellLayerNormWeights = {0.7f, 0.2f, 0.3f, 0.8f}; //{numUnits}
1386 
1387     std::vector<float> outputLayerNormWeights = {0.6f, 0.2f, 0.2f, 0.5f}; //{numUnits}
1388 
1389 
1390     armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo4x5);
1391     armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo4x5);
1392     armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo4x5);
1393     armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo4x5);
1394     armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo4x3);
1395     armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo4x3);
1396     armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo4x3);
1397     armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo4x3);
1398     armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4);
1399     armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4);
1400     armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4);
1401     armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4);
1402     armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4);
1403     armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo4);
1404     armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo4);
1405     armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo3x4);
1406     armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo3);
1407 
1408     armnn::ScopedTensorHandle inputLayerNormWeightsTensor(tensorInfo4);
1409     armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(tensorInfo4);
1410     armnn::ScopedTensorHandle cellLayerNormWeightsTensor(tensorInfo4);
1411     armnn::ScopedTensorHandle outputLayerNormWeightsTensor(tensorInfo4);
1412 
1413     AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
1414     AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1415     AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1416     AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
1417     AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
1418     AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1419     AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1420     AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
1421     AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
1422     AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
1423     AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1424     AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1425     AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
1426     AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
1427     AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
1428     AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
1429     AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data());
1430 
1431     AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data());
1432     AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
1433     AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
1434     AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
1435 
1436     data.m_InputToInputWeights = &inputToInputWeightsTensor;
1437     data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1438     data.m_InputToCellWeights = &inputToCellWeightsTensor;
1439     data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1440     data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
1441     data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1442     data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1443     data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1444     data.m_CellToInputWeights = &cellToInputWeightsTensor;
1445     data.m_InputGateBias = &inputGateBiasTensor;
1446     data.m_ForgetGateBias = &forgetGateBiasTensor;
1447     data.m_CellBias = &cellBiasTensor;
1448     data.m_OutputGateBias = &outputGateBiasTensor;
1449     data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
1450     data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
1451     data.m_ProjectionWeights = &projectionWeightsTensor;
1452     data.m_ProjectionBias = &projectionBiasTensor;
1453 
1454     data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
1455     data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
1456     data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
1457     data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
1458 
1459     // Flags to set test configuration
1460     data.m_Parameters.m_ActivationFunc = 4;
1461     data.m_Parameters.m_CifgEnabled = false;
1462     data.m_Parameters.m_PeepholeEnabled = true;
1463     data.m_Parameters.m_ProjectionEnabled = true;
1464     data.m_Parameters.m_LayerNormEnabled = true;
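    // Note on the values above: m_ActivationFunc = 4 is assumed here to follow the
    // Android NN / TensorFlow Lite fused-activation encoding, in which 4 selects TanH
    // for the cell input and cell output activations.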
1465 
1466 
1467     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
1468     inputHandle->Allocate();
1469     outputStateInHandle->Allocate();
1470     cellStateInHandle->Allocate();
1471 
1472     scratchHandle->Allocate();
1473     outputStateOutHandle->Allocate();
1474     cellStateOutHandle->Allocate();
1475     outputHandle->Allocate();
1476 
1477     CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1478     CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1479     CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
1480 
1481     workload->Execute();
1482 
1483     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
1484 
1485     return LayerTestResult<T, 2>(actualOutput,
1486                                  outputVector,
1487                                  outputHandle->GetShape(),
1488                                  outputTensorInfo.GetShape());
1489 }
1490 
1491 LayerTestResult<uint8_t, 2> QuantizedLstmTestImpl(
1492     armnn::IWorkloadFactory& workloadFactory,
1493     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
1494     const armnn::ITensorHandleFactory& tensorHandleFactory,
1495     const std::vector<uint8_t>& input,
1496     const std::vector<uint8_t>& outputExpected,
1497     const armnn::TensorShape& inputShape,
1498     const armnn::TensorShape& outputExpectedShape)
1499 {
1500     IgnoreUnused(memoryManager);
1501     auto numBatches = armnn::numeric_cast<unsigned int>(inputShape[0]);
1502     auto inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
1503     auto outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
1504 
1505     // Scale/Offset for input/output, cellState In/Out, weights, bias
1506     float inputOutputScale = 0.0078125f;
1507     int32_t inputOutputOffset = 128;
1508 
1509     float cellStateScale = 0.00048828125f;
1510     int32_t cellStateOffset = 0;
1511 
1512     float weightsScale = 0.00408021f;
1513     int32_t weightsOffset = 100;
1514 
1515     float biasScale = 3.1876640625e-05f;
1516     int32_t biasOffset = 0;
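    // Sanity check on the quantization parameters above: for 32-bit biases the scale is
    // conventionally the product of the input and weight scales, and indeed
    // 0.0078125f * 0.00408021f == 3.1876640625e-05f, which matches biasScale.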
1517 
1518     // Input/Output tensor info
1519     armnn::TensorInfo inputInfo({numBatches , inputSize},
1520                                  armnn::DataType::QAsymmU8,
1521                                  inputOutputScale,
1522                                  inputOutputOffset);
1523 
1524     armnn::TensorInfo cellStateInfo({numBatches , outputSize},
1525                                      armnn::DataType::QSymmS16,
1526                                      cellStateScale,
1527                                      cellStateOffset);
1528 
1529     armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1530                                        armnn::DataType::QAsymmU8,
1531                                        inputOutputScale,
1532                                        inputOutputOffset);
1533 
1534     // Input0
1535     std::vector<uint8_t> inputVector;
1536     inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
1537 
1538     // Input1
1539     std::vector<int16_t> cellStateInVector   = {876, 1034, 955, -909, 761, 1029, 796, -1036}; // 13
1540     // Input2
1541     std::vector<uint8_t> outputStateInVector = {136, 150, 140, 115, 135, 152, 138, 112}; // 14
1542 
1543     // Output0
1544     std::vector<int16_t> cellStateOutVector  = {1485, 1177, 1373, -1023, 1019, 1355, 1097, -1235}; // 0
1545 
1546     // Output1
1547     std::vector<uint8_t> outputVector; // 1
1548     outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
1549 
1550     std::vector<uint8_t> actualOutput(outputStateInfo.GetNumElements());
1551 
1552     // Create tensor handles
1553     std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
1554     std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1555             tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1556     std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1557             tensorHandleFactory.CreateTensorHandle(outputStateInfo);
1558 
1559     std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1560             tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1561     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
1562 
1563     armnn::QuantizedLstmQueueDescriptor data;
1564     armnn::WorkloadInfo info;
1565 
1566     // Add inputs and outputs to workload
1567     AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1568     AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1569     AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1570 
1571     AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1572     AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
1573 
1574     // Weights and bias tensor and quantization info
1575     armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
1576                                         armnn::DataType::QAsymmU8,
1577                                         weightsScale,
1578                                         weightsOffset);
1579 
1580     armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
1581                                             armnn::DataType::QAsymmU8,
1582                                             weightsScale,
1583                                             weightsOffset);
1584 
1585     armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
1586 
1587     // Weights and bias tensor data
1588     std::vector<uint8_t> inputToInputWeights  = {146, 250, 235, 171, 10, 218, 171, 108};
1589     std::vector<uint8_t> inputToForgetWeights = {24, 50, 132, 179, 158, 110, 3, 169};
1590     std::vector<uint8_t> inputToCellWeights   = {133, 34, 29, 49, 206, 109, 54, 183};
1591     std::vector<uint8_t> inputToOutputWeights = {195, 187, 11, 99, 109, 10, 218, 48};
1592 
1593     std::vector<uint8_t> recurrentToInputWeights =
1594             {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26};
1595     std::vector<uint8_t> recurrentToForgetWeights =
1596             {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253};
1597     std::vector<uint8_t> recurrentToCellWeights =
1598             {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216};
1599     std::vector<uint8_t> recurrentToOutputWeights =
1600             {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98};
1601 
1602     std::vector<int32_t> inputGateBias  = {-7876, 13488, -726, 32839};
1603     std::vector<int32_t> forgetGateBias = {9206, -46884, -11693, -38724};
1604     std::vector<int32_t> cellBias       = {39481, 48624, 48976, -21419};
1605     std::vector<int32_t> outputGateBias = {-58999, -17050, -41852, -40538};
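    // The raw uint8 weight values above decode with the usual asymmetric rule
    // real = weightsScale * (quantized - weightsOffset); for example the first
    // inputToInputWeights entry, 146, corresponds to 0.00408021f * (146 - 100) ≈ 0.188f.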
1606 
1607     // ScopedTensorHandles
1608     armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo);
1609     armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
1610     armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
1611     armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
1612 
1613     armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo);
1614     armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
1615     armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
1616     armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
1617 
1618     armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo);
1619     armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
1620     armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
1621     armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
1622 
1623     // Allocate and copy data
1624     AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
1625     AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1626     AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1627     AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
1628 
1629     AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
1630     AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1631     AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1632     AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
1633 
1634     AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
1635     AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1636     AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1637     AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
1638 
1639     // Setup queue descriptor
1640     data.m_InputToInputWeights = &inputToInputWeightsTensor;
1641     data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1642     data.m_InputToCellWeights = &inputToCellWeightsTensor;
1643     data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1644 
1645     data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
1646     data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1647     data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1648     data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1649 
1650     data.m_InputGateBias = &inputGateBiasTensor;
1651     data.m_ForgetGateBias = &forgetGateBiasTensor;
1652     data.m_CellBias = &cellBiasTensor;
1653     data.m_OutputGateBias = &outputGateBiasTensor;
1654 
1655     // Create workload and allocate tensor handles
1656     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QuantizedLstm,
1657                                                                                 data,
1658                                                                                 info);
1659     inputHandle->Allocate();
1660     outputStateInHandle->Allocate();
1661     cellStateInHandle->Allocate();
1662 
1663     cellStateOutHandle->Allocate();
1664     outputHandle->Allocate();
1665 
1666     CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1667     CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1668     CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
1669 
1670     workload->Execute();
1671 
1672     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
1673 
1674     return LayerTestResult<uint8_t, 2>(actualOutput,
1675                                        outputVector,
1676                                        outputHandle->GetShape(),
1677                                        outputStateInfo.GetShape());
1678 }
1679 
1680 // QLSTM: CIFG, LayerNorm
1681 LayerTestResult<int8_t, 2> QLstmTestImpl(
1682         armnn::IWorkloadFactory& workloadFactory,
1683         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
1684         const armnn::ITensorHandleFactory& tensorHandleFactory,
1685         const std::vector<int8_t>& input,
1686         const std::vector<int8_t>& outputExpected)
1687 {
1688     IgnoreUnused(memoryManager);
1689     unsigned int numBatches = 2;
1690     unsigned int inputSize  = 5;
1691     unsigned int outputSize = 4;
1692     unsigned int numUnits   = 4;
1693 
1694     bool cifgEnabled       = true;
1695     bool peepholeEnabled   = false;
1696     bool projectionEnabled = false;
1697     bool layerNormEnabled  = true;
1698 
1699     // Scale/Offset quantization info
1700     float inputScale    = 0.0078125f;
1701     int32_t inputOffset = 0;
1702 
1703     int32_t hiddenStateZeroPoint = 0;
1704     float hiddenStateScale       = 0.007f;
1705 
1706     // if (!projectionEnabled) outputScale == hiddenStateScale
1707     float outputScale    = hiddenStateScale;
1708     int32_t outputOffset = hiddenStateZeroPoint;
1709 
1710     float cellStateScale    = 3.05176e-05f;
1711     int32_t cellStateOffset = 0;
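    // The QSymmS16 cell state scale is ~2^-15 (1/32768 ≈ 3.0517578e-05), so the int16 cell
    // state values map to roughly (-1, 1); this reads like Q0.15 fixed point.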
1712 
1713     float weightsScale    = 0.00784314f;
1714     int32_t weightsOffset = 0;
1715 
1716     float layerNormScale    = 3.05182e-05f;
1717     int32_t layerNormOffset = 0;
1718 
1719     float biasScale    = layerNormScale / 1024;
1720     int32_t biasOffset = 0;
1721 
1722     float inputIntermediateScale  = 0.007059f;
1723     float forgetIntermediateScale = 0.007812f;
1724     float cellIntermediateScale   = inputIntermediateScale;
1725     float outputIntermediateScale = forgetIntermediateScale;
1726 
1727     float cellClip       = 0.0f;
1728     float projectionClip = 0.0f;
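    // A clip value of 0.0f is treated as "no clipping" here, assuming the usual
    // TfLite / Arm NN convention that clipping is only applied when the clip value is > 0.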
1729 
1730     // Input/Output tensor info
1731     armnn::TensorInfo inputInfo({numBatches , inputSize},
1732                                 armnn::DataType::QAsymmS8,
1733                                 inputScale,
1734                                 inputOffset);
1735 
1736     armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1737                                     armnn::DataType::QSymmS16,
1738                                     cellStateScale,
1739                                     cellStateOffset);
1740 
1741     armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1742                                       armnn::DataType::QAsymmS8,
1743                                       outputScale,
1744                                       outputOffset);
1745 
1746     LayerTestResult<int8_t, 2> ret(outputStateInfo);
1747 
1748     // Input tensors
1749     std::vector<int8_t> inputVector;
1750     inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
1751 
1752     std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1753 
1754     std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1755 
1756     // Output tensors
1757     std::vector<int16_t> cellStateOutVector = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149};
1758 
1759     std::vector<int8_t> outputVector;
1760     outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
1761 
1762     std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
1763 
1764     // Create tensor handles
1765     std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
1766     std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1767             tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1768     std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1769             tensorHandleFactory.CreateTensorHandle(outputStateInfo);
1770 
1771     std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1772             tensorHandleFactory.CreateTensorHandle(outputStateInfo);
1773     std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1774             tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1775     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
1776 
1777     armnn::QLstmQueueDescriptor data;
1778     armnn::WorkloadInfo info;
1779 
1780     // Add inputs and outputs to workload
1781     AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1782     AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1783     AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1784 
1785     AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
1786     AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1787     AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
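    // The QLstm workload takes three inputs (input, outputStateIn, cellStateIn) and produces
    // three outputs (outputStateOut, cellStateOut, output), in the order they are added above.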
1788 
1789     // Weights and bias tensor and quantization info
1790     armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
1791                                        armnn::DataType::QSymmS8,
1792                                        weightsScale,
1793                                        weightsOffset);
1794 
1795     armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
1796                                            armnn::DataType::QSymmS8,
1797                                            weightsScale,
1798                                            weightsOffset);
1799 
1800     armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
1801 
1802     armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
1803 
1804     // Weights and bias tensor data
1805     std::vector<int8_t> inputToForgetWeights =
1806             {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
1807     std::vector<int8_t> inputToCellWeights   =
1808             {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
1809     std::vector<int8_t> inputToOutputWeights =
1810             {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
1811 
1812     std::vector<int8_t> recurrentToForgetWeights =
1813             {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51};
1814     std::vector<int8_t> recurrentToCellWeights   =
1815             {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64};
1816     std::vector<int8_t> recurrentToOutputWeights =
1817             {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38};
1818 
1819     std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
1820     std::vector<int32_t> cellBias       = {-1073742, 15461883, 5368709, 1717987};
1821     std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
1822 
1823     std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
1824     std::vector<int16_t> cellLayerNormWeights   = {22937, 6553, 9830, 26214};
1825     std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
1826 
1827     // ScopedTensorHandles
1828     armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
1829     armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
1830     armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
1831 
1832     armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
1833     armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
1834     armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
1835 
1836     armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
1837     armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
1838     armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
1839 
1840     armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
1841     armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
1842     armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
1843 
1844     // Allocate and copy data
1845     AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1846     AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1847     AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
1848 
1849     AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1850     AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1851     AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
1852 
1853     AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1854     AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1855     AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
1856 
1857     AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
1858     AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
1859     AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
1860 
1861     // Setup queue descriptor
1862     data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1863     data.m_InputToCellWeights = &inputToCellWeightsTensor;
1864     data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1865 
1866     data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1867     data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1868     data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1869 
1870     data.m_ForgetGateBias = &forgetGateBiasTensor;
1871     data.m_CellBias = &cellBiasTensor;
1872     data.m_OutputGateBias = &outputGateBiasTensor;
1873 
1874     data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
1875     data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
1876     data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
1877 
1878     data.m_Parameters.m_CifgEnabled = cifgEnabled;
1879     data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
1880     data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
1881     data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
1882 
1883     data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
1884     data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
1885     data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
1886     data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
1887 
1888     data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
1889     data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
1890 
1891     data.m_Parameters.m_CellClip = cellClip;
1892     data.m_Parameters.m_ProjectionClip = projectionClip;
1893 
1894     // Create workload and allocate tensor handles
1895     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
1896     inputHandle->Allocate();
1897     outputStateInHandle->Allocate();
1898     cellStateInHandle->Allocate();
1899 
1900     outputStateOutHandle->Allocate();
1901     cellStateOutHandle->Allocate();
1902     outputHandle->Allocate();
1903 
1904     CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1905     CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1906     CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
1907 
1908     workload->Execute();
1909 
1910     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
1911 
1912     return LayerTestResult<int8_t, 2>(actualOutput,
1913                                       outputVector,
1914                                       outputHandle->GetShape(),
1915                                       outputStateInfo.GetShape());
1916 }
1917 
1918 // QLSTM: Projection, LayerNorm
1919 LayerTestResult<int8_t, 2> QLstmTestImpl1(
1920         armnn::IWorkloadFactory& workloadFactory,
1921         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
1922         const armnn::ITensorHandleFactory& tensorHandleFactory,
1923         const std::vector<int8_t>& input,
1924         const std::vector<int8_t>& outputExpected)
1925 {
1926     IgnoreUnused(memoryManager);
1927     unsigned int numBatches = 2;
1928     unsigned int inputSize  = 5;
1929     unsigned int outputSize = 3;
1930     unsigned int numUnits   = 4;
1931 
1932     bool cifgEnabled       = false;
1933     bool peepholeEnabled   = false;
1934     bool projectionEnabled = true;
1935     bool layerNormEnabled  = true;
1936 
1937     // Scale/Offset quantization info
1938     float inputScale    = 0.0078125f;
1939     int32_t inputOffset = 0;
1940 
1941     int32_t hiddenStateZeroPoint = 0;
1942     float hiddenStateScale       = 0.007f;
1943 
1944     // if (!projectionEnabled) outputScale == hiddenStateScale
1945     float outputScale    = 3.05176e-05f;
1946     int32_t outputOffset = 0;
1947 
1948     float cellStateScale    = 3.05176e-05f;
1949     int32_t cellStateOffset = 0;
1950 
1951     float weightsScale    = 0.00784314f;
1952     int32_t weightsOffset = 0;
1953 
1954     float layerNormScale    = 3.05182e-05f;
1955     int32_t layerNormOffset = 0;
1956 
1957     float biasScale    = layerNormScale / 1024;
1958     int32_t biasOffset = 0;
1959 
1960     float projectionWeightsScale = 0.00392157f;
1961 
1962     float inputIntermediateScale  = 0.007059f;
1963     float forgetIntermediateScale = 0.007812f;
1964     float cellIntermediateScale   = inputIntermediateScale;
1965     float outputIntermediateScale = forgetIntermediateScale;
1966 
1967     float cellClip       = 0.0f;
1968     float projectionClip = 0.0f;
1969 
1970     // Input/Output tensor info
1971     armnn::TensorInfo inputInfo({numBatches , inputSize},
1972                                 armnn::DataType::QAsymmS8,
1973                                 inputScale,
1974                                 inputOffset);
1975 
1976     armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1977                                     armnn::DataType::QSymmS16,
1978                                     cellStateScale,
1979                                     cellStateOffset);
1980 
1981     armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1982                                       armnn::DataType::QAsymmS8,
1983                                       outputScale,
1984                                       outputOffset);
1985 
1986     // Input tensors
1987     std::vector<int8_t> inputVector;
1988     inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
1989 
1990     std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1991 
1992     std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
1993 
1994     // Output tensors
1995     std::vector<int16_t> cellStateOutVector  = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
1996 
1997     std::vector<int8_t> outputVector;
1998     outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
1999 
2000     std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
2001 
2002     // Create tensor handles
2003     std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
2004     std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
2005             tensorHandleFactory.CreateTensorHandle(cellStateInfo);
2006     std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
2007             tensorHandleFactory.CreateTensorHandle(outputStateInfo);
2008 
2009     std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2010             tensorHandleFactory.CreateTensorHandle(outputStateInfo);
2011     std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
2012             tensorHandleFactory.CreateTensorHandle(cellStateInfo);
2013     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
2014 
2015     armnn::QLstmQueueDescriptor data;
2016     armnn::WorkloadInfo info;
2017 
2018     // Add inputs and outputs to workload
2019     AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2020     AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2021     AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2022 
2023     AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2024     AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2025     AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2026 
2027     // Weights and bias tensor and quantization info
2028     armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
2029                                        armnn::DataType::QSymmS8,
2030                                        weightsScale,
2031                                        weightsOffset);
2032 
2033     armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
2034                                            armnn::DataType::QSymmS8,
2035                                            weightsScale,
2036                                            weightsOffset);
2037 
2038     armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);
2039 
2040     armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
2041 
2042     armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
2043                                             armnn::DataType::QSymmS8,
2044                                             projectionWeightsScale,
2045                                             0);
2046 
2047     // Weights and bias tensor data
2048     std::vector<int8_t> inputToInputWeights =
2049             {64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13};
2050     std::vector<int8_t> inputToForgetWeights =
2051             {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2052     std::vector<int8_t> inputToCellWeights   =
2053             {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2054     std::vector<int8_t> inputToOutputWeights =
2055             {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
2056 
2057     std::vector<int8_t> recurrentToInputWeights  = {-25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77};
2058     std::vector<int8_t> recurrentToForgetWeights = {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2059     std::vector<int8_t> recurrentToCellWeights   = {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2060     std::vector<int8_t> recurrentToOutputWeights = {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
2061 
2062     std::vector<int32_t> inputGateBias  = {644245, 3221226, 4724464, 8160438};
2063     std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2064     std::vector<int32_t> cellBias       = {-1073742, 15461883, 5368709, 1717987};
2065     std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
2066 
2067     std::vector<int16_t> inputLayerNormWeights = {3277, 6553, 9830, 16384};
2068     std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2069     std::vector<int16_t> cellLayerNormWeights   = {22937, 6553, 9830, 26214};
2070     std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
2071 
2072     std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
2073 
2074     // ScopedTensorHandles
2075     armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo);
2076     armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
2077     armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
2078     armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
2079 
2080     armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo);
2081     armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
2082     armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
2083     armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
2084 
2085     armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo);
2086     armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
2087     armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
2088     armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
2089 
2090     armnn::ScopedTensorHandle inputLayerNormWeightsTensor(layerNormWeightsInfo);
2091     armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
2092     armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
2093     armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
2094 
2095     armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo);
2096 
2097     // Allocate and copy data
2098     AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
2099     AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
2100     AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
2101     AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
2102 
2103     AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
2104     AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
2105     AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
2106     AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
2107 
2108     AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
2109     AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
2110     AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
2111     AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
2112 
2113     AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data());
2114     AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
2115     AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
2116     AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
2117 
2118     AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
2119 
2120     // Setup queue descriptor
2121     data.m_InputToInputWeights = &inputToInputWeightsTensor;
2122     data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
2123     data.m_InputToCellWeights = &inputToCellWeightsTensor;
2124     data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
2125 
2126     data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
2127     data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
2128     data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
2129     data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
2130 
2131     data.m_InputGateBias = &inputGateBiasTensor;
2132     data.m_ForgetGateBias = &forgetGateBiasTensor;
2133     data.m_CellBias = &cellBiasTensor;
2134     data.m_OutputGateBias = &outputGateBiasTensor;
2135 
2136     data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
2137     data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
2138     data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
2139     data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
2140 
2141     data.m_ProjectionWeights = &projectionWeightsTensor;
2142 
2143     data.m_Parameters.m_CifgEnabled = cifgEnabled;
2144     data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
2145     data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
2146     data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
2147 
2148     data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
2149     data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
2150     data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
2151     data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
2152 
2153     data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
2154     data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
2155 
2156     data.m_Parameters.m_CellClip = cellClip;
2157     data.m_Parameters.m_ProjectionClip = projectionClip;
2158 
2159     // Create workload and allocate tensor handles
2160     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
2161     inputHandle->Allocate();
2162     outputStateInHandle->Allocate();
2163     cellStateInHandle->Allocate();
2164 
2165     outputStateOutHandle->Allocate();
2166     cellStateOutHandle->Allocate();
2167     outputHandle->Allocate();
2168 
2169     CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
2170     CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
2171     CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
2172 
2173     workload->Execute();
2174 
2175     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
2176 
2177     return LayerTestResult<int8_t, 2>(actualOutput,
2178                                       outputVector,
2179                                       outputHandle->GetShape(),
2180                                       outputStateInfo.GetShape());
2181 }
2182 
2183 // QLSTM: Projection, CIFG, LayerNorm
2184 LayerTestResult<int8_t, 2> QLstmTestImpl2(
2185         armnn::IWorkloadFactory& workloadFactory,
2186         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2187         const armnn::ITensorHandleFactory& tensorHandleFactory,
2188         const std::vector<int8_t>& input,
2189         const std::vector<int8_t>& outputExpected)
2190 {
2191     IgnoreUnused(memoryManager);
2192     unsigned int numBatches = 2;
2193     unsigned int inputSize  = 5;
2194     unsigned int outputSize = 3;
2195     unsigned int numUnits   = 4;
2196 
2197     bool cifgEnabled       = true;
2198     bool peepholeEnabled   = false;
2199     bool projectionEnabled = true;
2200     bool layerNormEnabled  = true;
2201 
2202     // Scale/Offset quantization info
2203     float inputScale    = 0.0078125f;
2204     int32_t inputOffset = 0;
2205 
2206     int32_t hiddenStateZeroPoint = 0;
2207     float hiddenStateScale       = 0.007f;
2208 
2209     // if (!projectionEnabled) outputScale == hiddenStateScale
2210     float outputScale    = 3.05176e-05f;
2211     int32_t outputOffset = 0;
2212 
2213     float cellStateScale    = 3.05176e-05f;
2214     int32_t cellStateOffset = 0;
2215 
2216     float weightsScale    = 0.00784314f;
2217     int32_t weightsOffset = 0;
2218 
2219     float layerNormScale    = 3.05182e-05f;
2220     int32_t layerNormOffset = 0;
2221 
2222     float biasScale    = layerNormScale / 1024;
2223     int32_t biasOffset = 0;
2224 
2225     float projectionWeightsScale = 0.00392157f;
2226 
2227     float inputIntermediateScale  = 0.007059f;
2228     float forgetIntermediateScale = 0.007812f;
2229     float cellIntermediateScale   = inputIntermediateScale;
2230     float outputIntermediateScale = forgetIntermediateScale;
2231 
2232     float cellClip       = 0.0f;
2233     float projectionClip = 0.0f;
2234 
2235     // Input/Output tensor info
2236     armnn::TensorInfo inputInfo({numBatches , inputSize},
2237                                 armnn::DataType::QAsymmS8,
2238                                 inputScale,
2239                                 inputOffset);
2240 
2241     armnn::TensorInfo cellStateInfo({numBatches , numUnits},
2242                                     armnn::DataType::QSymmS16,
2243                                     cellStateScale,
2244                                     cellStateOffset);
2245 
2246     armnn::TensorInfo outputStateInfo({numBatches , outputSize},
2247                                       armnn::DataType::QAsymmS8,
2248                                       outputScale,
2249                                       outputOffset);
2250 
2251     // Input tensors
2252     std::vector<int8_t> inputVector;
2253     inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
2254 
2255     std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
2256 
2257     std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
2258 
2259     // Output tensors
2260     std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
2261 
2262     std::vector<int8_t> outputVector;
2263     outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
2264 
2265     std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
2266 
2267     // Create tensor handles
2268     std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
2269     std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
2270             tensorHandleFactory.CreateTensorHandle(cellStateInfo);
2271     std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
2272             tensorHandleFactory.CreateTensorHandle(outputStateInfo);
2273 
2274     std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2275             tensorHandleFactory.CreateTensorHandle(outputStateInfo);
2276     std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
2277             tensorHandleFactory.CreateTensorHandle(cellStateInfo);
2278     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
2279 
2280     armnn::QLstmQueueDescriptor data;
2281     armnn::WorkloadInfo info;
2282 
2283     // Add inputs and outputs to workload
2284     AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2285     AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2286     AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2287 
2288     AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2289     AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2290     AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2291 
2292     // Weights and bias tensor and quantization info
2293     armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
2294                                        armnn::DataType::QSymmS8,
2295                                        weightsScale,
2296                                        weightsOffset);
2297 
2298     armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
2299                                            armnn::DataType::QSymmS8,
2300                                            weightsScale,
2301                                            weightsOffset);
2302 
2303     armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);
2304 
2305     armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
2306 
2307     armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
2308                                             armnn::DataType::QSymmS8,
2309                                             projectionWeightsScale,
2310                                             0);
2311 
2312     // Weights and bias tensor data
2313     std::vector<int8_t> inputToForgetWeights =
2314             {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2315     std::vector<int8_t> inputToCellWeights   =
2316             {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2317     std::vector<int8_t> inputToOutputWeights =
2318             {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
2319 
2320     std::vector<int8_t> recurrentToForgetWeights =
2321             {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2322     std::vector<int8_t> recurrentToCellWeights =
2323             {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2324     std::vector<int8_t> recurrentToOutputWeights =
2325             {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
2326 
2327     std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2328     std::vector<int32_t> cellBias       = {-1073742, 15461883, 5368709, 1717987};
2329     std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
2330 
2331     std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2332     std::vector<int16_t> cellLayerNormWeights   = {22937, 6553, 9830, 26214};
2333     std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
2334 
2335     std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
2336 
2337     // ScopedTensorHandles
2338     armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
2339     armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
2340     armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
2341 
2342     armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
2343     armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
2344     armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
2345 
2346     armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
2347     armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
2348     armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
2349 
2350     armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
2351     armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
2352     armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
2353 
2354     armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo);
2355 
2356     // Allocate and copy data
2357     AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
2358     AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
2359     AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
2360 
2361     AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
2362     AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
2363     AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
2364 
2365     AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
2366     AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
2367     AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
2368 
2369     AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
2370     AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
2371     AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
2372 
2373     AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
2374 
2375     // Setup queue descriptor
2376     data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
2377     data.m_InputToCellWeights = &inputToCellWeightsTensor;
2378     data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
2379 
2380     data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
2381     data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
2382     data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
2383 
2384     data.m_ForgetGateBias = &forgetGateBiasTensor;
2385     data.m_CellBias = &cellBiasTensor;
2386     data.m_OutputGateBias = &outputGateBiasTensor;
2387 
2388     data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
2389     data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
2390     data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
2391 
2392     data.m_ProjectionWeights = &projectionWeightsTensor;
2393 
2394     data.m_Parameters.m_CifgEnabled = cifgEnabled;
2395     data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
2396     data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
2397     data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
2398 
2399     data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
2400     data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
2401     data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
2402     data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
2403 
2404     data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
2405     data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
2406 
2407     data.m_Parameters.m_CellClip = cellClip;
2408     data.m_Parameters.m_ProjectionClip = projectionClip;
2409 
2410     // Create workload and allocate tensor handles
2411     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
2412     inputHandle->Allocate();
2413     outputStateInHandle->Allocate();
2414     cellStateInHandle->Allocate();
2415 
2416     outputStateOutHandle->Allocate();
2417     cellStateOutHandle->Allocate();
2418     outputHandle->Allocate();
2419 
2420     CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
2421     CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
2422     CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
2423 
2424     workload->Execute();
2425 
2426     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
2427 
2428     return LayerTestResult<int8_t, 2>(actualOutput,
2429                                       outputVector,
2430                                       outputHandle->GetShape(),
2431                                       outputStateInfo.GetShape());
2432 }
2433 
2434 
2435 } // anonymous namespace
2436 
2437 #if defined(ARMNNREF_ENABLED)
2438 
2439 // The LSTM unit tests are currently run only on the reference backend
2440 
2441 void LstmUtilsZeroVectorTest()
2442 {
2443     armnn::TensorInfo inputDesc({4}, armnn::DataType::Float32);
2444     std::vector<float> input = {2., 3., 3., 4.};
2445     std::vector<float> expectedOutput = {0., 0., 0., 0.};
2446 
2447     return LstmUtilsZeroVectorTestImpl<armnn::DataType::Float32>(input, 4, expectedOutput, inputDesc.GetShape());
2448 }
2449 
2450 void LstmUtilsMeanStddevNormalizationNoneZeroInputTest()
2451 {
2452     uint32_t batchSize = 2;
2453     uint32_t vecSize = 4;
2454     armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
2455     std::vector<float> input =
2456             { 0.1f, 0.2f, 0.3f, 0.4f,    //batch 0
2457               0.9f, 1.0f, 1.1f, 1.2f };  //batch 1
2458 
2459     std::vector<float> expectedOutput =
2460             { -1.34164071f, -0.447213531f, 0.44721365f,  1.34164071f,    //batch 0
2461               -1.34163153f, -0.447210163f, 0.447211236f, 1.3416326f  };  //batch 1
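    // Worked example for batch 0: mean = 0.25, variance = 0.0125, stddev ≈ 0.1118034,
    // so (0.1 - 0.25) / 0.1118034 ≈ -1.34164, matching the first expected value. The slightly
    // different batch 1 values presumably reflect a small stabilising epsilon in the implementation.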
2462 
2463     return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2464             vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2465 }
2466 
2467 void LstmUtilsMeanStddevNormalizationAllZeroInputTest()
2468 {
2469     uint32_t batchSize = 2;
2470     uint32_t vecSize = 4;
2471     armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
2472     std::vector<float> input =
2473             { 0.0f, 0.0f, 0.0f, 0.0f,      //batch 0
2474               0.0f, 0.0f, 0.0f, 0.0f };  //batch 1
2475 
2476     std::vector<float> expectedOutput =
2477             { 0.0f, 0.0f, 0.0f, 0.0f,      //batch 0
2478               0.0f, 0.0f, 0.0f, 0.0f };  //batch 1
2479 
2480     return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2481             vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2482 }
2483 
2484 void LstmUtilsMeanStddevNormalizationMixedZeroInputTest()
2485 {
2486     uint32_t batchSize = 2;
2487     uint32_t vecSize = 4;
2488     armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
2489     std::vector<float> input =
2490             { 0.0f, 0.0f, 0.0f, 0.0f,    //batch 0
2491               0.1f, 0.2f, 0.3f, 0.4f };  //batch 1
2492 
2493     std::vector<float> expectedOutput =
2494             {         0.0f,          0.0f,        0.0f,        0.0f,    //batch 0
2495               -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f };  //batch 1
2496 
2497     return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2498             vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2499 }
2500 
2501 void LstmUtilsVectorBatchVectorCwiseProductTest()
2502 {
2503     uint32_t batchSize = 4;
2504     uint32_t vecSize = 29;
2505     armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32);
2506     std::vector<float> vector =
2507             {   1.1f,   2.2f,   3.3f,   4.4f,   5.5f,   6.6f,   7.7f,   8.8f,   9.9f, 10.1f,
2508               11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
2509               21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f,     0.0f};
2510 
2511     armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32);
2512     std::vector<float> batchVector =
2513             { /* batch 0 */
2514                 1.1f,   2.2f,   3.3f,   4.4f,   5.5f,   6.6f,   7.7f,   8.8f,   9.9f,  10.1f,
2515               11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f,  20.2f,
2516               21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f,   0.0f,
2517               /* batch 1 */
2518                 -1.1f,   -2.2f,   -3.3f,   -4.4f,   -5.5f,   -6.6f,   -7.7f,   -8.8f,   -9.9f, -10.1f,
2519               -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f, -19.19f, -20.2f,
2520               -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f, -27.27f, -28.28f,    0.0f,
2521               /* batch 2 */
2522                 1.1f,   -2.2f,   3.3f,   -4.4f,   5.5f,   -6.6f,   7.7f,   -8.8f,   9.9f, -10.1f,
2523               11.11f, -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f, -20.2f,
2524               21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f, -28.28f,   0.0f,
2525               /* batch 3 */
2526                 -1.1f,   2.2f,   -3.3f,   4.4f,   -5.5f,   6.6f,   -7.7f,   8.8f,   -9.9f, 10.1f,
2527               -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f, -19.19f, 20.2f,
2528               -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f, -27.27f, 28.28f,    0.0f};
2529 
2530     // Expect each output row to be the element-wise product of 'vector' and the corresponding row of 'batchVector'.
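    // For example, batch 0 element 0: 1.1f * 1.1f = 1.21f; batch 1 element 9: 10.1f * -10.1f = -102.01f.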
2531     std::vector<float> expectedOutput =
2532             { /* batch 0 */
2533                  1.210000f,    4.840000f,   10.889999f,   19.360001f,   30.250000f,   43.559998f,
2534                 59.289997f,   77.440002f,   98.009995f,  102.010010f,  123.432091f,  146.894394f,
2535                172.396896f,  199.939606f,  229.522491f,  261.145599f,  294.808899f,  330.512421f,
2536                368.256134f,  408.040039f,  449.864075f,  493.728363f,  539.632874f,  587.577576f,
2537                637.562500f,  689.587585f,  743.652954f,  799.758423f,    0.000000f,
2538               /* batch 1 */
2539                 -1.210000f,   -4.840000f,  -10.889999f,  -19.360001f,  -30.250000f,  -43.559998f,
2540                -59.289997f,  -77.440002f,  -98.009995f, -102.010010f, -123.432091f, -146.894394f,
2541               -172.396896f, -199.939606f, -229.522491f, -261.145599f, -294.808899f, -330.512421f,
2542               -368.256134f, -408.040039f, -449.864075f, -493.728363f, -539.632874f, -587.577576f,
2543               -637.562500f, -689.587585f, -743.652954f, -799.758423f,    0.000000f,
2544               /* batch 2 */
2545                  1.210000f,   -4.840000f,  10.889999f,   -19.360001f,   30.250000f,  -43.559998f,
2546                 59.289997f,  -77.440002f,  98.009995f,  -102.010010f,  123.432091f, -146.894394f,
2547                172.396896f, -199.939606f, 229.522491f,  -261.145599f,  294.808899f, -330.512421f,
2548                368.256134f, -408.040039f, 449.864075f,  -493.728363f,  539.632874f, -587.577576f,
2549                637.562500f, -689.587585f, 743.652954f,  -799.758423f,    0.000000f,
2550               /* batch 3 */
2551                 -1.210000f,    4.840000f,  -10.889999f,   19.360001f,  -30.250000f,   43.559998f,
2552                -59.289997f,   77.440002f,  -98.009995f,  102.010010f, -123.432091f,  146.894394f,
2553               -172.396896f,  199.939606f, -229.522491f,  261.145599f, -294.808899f,  330.512421f,
2554               -368.256134f,  408.040039f, -449.864075f,  493.728363f, -539.632874f,  587.577576f,
2555               -637.562500f,  689.587585f, -743.652954f,  799.758423f,    0.000000f};
2556 
2557     return LstmUtilsVectorBatchVectorCwiseProductTestImpl<armnn::DataType::Float32>(vector, batchVector,
2558             vecSize, batchSize, expectedOutput, vecDesc.GetShape());
2559 }
2560 
2561 void LstmUtilsVectorBatchVectorAddTest()
2562 {
2563     uint32_t batchSize = 2;
2564     uint32_t vecSize = 3;
2565     armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32);
2566     std::vector<float> vector = { 0.0f, -0.5f, 1.0f};
2567 
2568     armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32);
2569     std::vector<float> batchVector =
2570     {
2571         1.0f, 2.0f, 3.0f, //batch 0
2572         4.0f, 5.0f, 6.0f  //batch 1
2573     };
2574 
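    // VectorBatchVectorAdd adds 'vector' to every row of 'batchVector' in place,
    // e.g. batch 0: {1.0 + 0.0, 2.0 - 0.5, 3.0 + 1.0} = {1.0, 1.5, 4.0}.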
2575     std::vector<float> expectedOutput =
2576     {
2577         1.0f, 1.5f, 4.0f,
2578         4.0f, 4.5f, 7.0f
2579     };
2580 
2581     return LstmUtilsVectorBatchVectorAddTestImpl<armnn::DataType::Float32>(vector, batchVector,
2582             vecSize, batchSize, expectedOutput, batchVecDesc.GetShape());
2583 }
2584 
2585 #endif
2586 
2587 LayerTestResult<float, 2> LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest(
2588     armnn::IWorkloadFactory& workloadFactory,
2589     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2590     const armnn::ITensorHandleFactory& tensorHandleFactory)
2591 {
2592     armnn::TensorInfo inputDesc({ 2, 2 }, armnn::DataType::Float32);
2593     std::vector<float> input = { 2., 3., 3., 4. };
2594 
2595     armnn::TensorInfo outputDesc({ 2, 4 }, armnn::DataType::Float32);
2596     std::vector<float> expectedOutput =
2597             {-0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2598              -0.42734814f, -0.00478661f,  0.13455015f, -0.03560682f};
2599     return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
2600         workloadFactory, memoryManager, tensorHandleFactory,
2601         input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2602 }
2603 
2604 LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest(
2605     armnn::IWorkloadFactory& workloadFactory,
2606     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2607     const armnn::ITensorHandleFactory& tensorHandleFactory)
2608 {
2609     armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32);
2610     std::vector<float> input =
2611             {0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2612              0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f};
2613 
2614     armnn::TensorInfo outputDesc({ 2, 16 }, armnn::DataType::Float32);
2615     std::vector<float> expectedOutput =
2616             {-0.00396806f,  0.029352f,    -0.00279226f,  0.0159977f,   -0.00835576f,
2617              -0.0211779f,   0.0283512f,   -0.0114597f,   0.00907307f,  -0.0244004f,
2618              -0.0152191f,  -0.0259063f,    0.00914318f,  0.00415118f,   0.017147f,
2619               0.0134203f,  -0.013869f,     0.0287268f,  -0.00334693f,   0.00733398f,  -0.0287926f,
2620              -0.0186926f,   0.0193662f,   -0.0115437f,   0.00422612f,  -0.0345232f,
2621               0.00223253f, -0.00957321f,   0.0210624f,   0.013331f,     0.0150954f,    0.02168f};
2622     return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<armnn::DataType::Float32>(
2623         workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2624 }
2625 
2626 LayerTestResult<float, 2> LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest(
2627     armnn::IWorkloadFactory& workloadFactory,
2628     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2629     const armnn::ITensorHandleFactory& tensorHandleFactory)
2630 {
2631     armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::Float32);
2632     std::vector<float> input = {2., 3., 3., 4.};
2633 
2634     armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::Float32);
2635     std::vector<float> expectedOutput =
2636             {-0.02973187f, 0.1229473f,   0.20885126f, -0.15358765f,
2637              -0.0185422f,   0.11281417f,  0.24466537f, -0.1826292f};
2638 
2639     return LstmNoCifgNoPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
2640         workloadFactory, memoryManager, tensorHandleFactory,
2641         input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2642 }
2643 
2644 LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest(
2645     armnn::IWorkloadFactory& workloadFactory,
2646     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2647     const armnn::ITensorHandleFactory& tensorHandleFactory)
2648 {
2649     armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32);
2650     std::vector<float> input =
2651             {0.7f, 0.8f, 0.1f, 0.2f, 0.3f,     //batch 0
2652              0.3f, 0.2f, 0.9f, 0.8f, 0.1f};  //batch 1
2653 
2654     armnn::TensorInfo outputDesc({ 2, 3 }, armnn::DataType::Float32);
2655     std::vector<float> expectedOutput =
2656             {  0.0244077f,  0.128027f, -0.00170918f,    //batch 0
2657               -0.00692428f, 0.0848741f,    0.063445f};  //batch 1
2658     return LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl<armnn::DataType::Float32>(
2659         workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2660 }
2661 
2662 LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionTest(
2663     armnn::IWorkloadFactory& workloadFactory,
2664     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2665     const armnn::ITensorHandleFactory& tensorHandleFactory)
2666 {
2667     const float qScale = 1.0f;
2668     const int32_t qOffset = 0;
2669 
2670     const armnn::DataType datatype = armnn::DataType::QSymmS16;
2671     const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
2672 
2673     armnn::TensorInfo inputDesc({2, 2}, datatype);
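    // armnnUtils::QuantizedVector quantises each float as roughly round(value / qScale) + qOffset,
    // clamped to the target type's range; with qScale = 1 and qOffset = 0 these values round straight to int16.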
2674     std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2675 
2676     armnn::TensorInfo outputDesc({2, 4}, datatype);
2677     std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2678         {
2679             -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2680             -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2681         },
2682         qScale, qOffset);
2683 
2684     return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
2685         workloadFactory, memoryManager, tensorHandleFactory,
2686         input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2687         qScale, qOffset, constantDatatype);
2688 
2689 }
2690 
2691 LayerTestResult<int16_t, 2> LstmLayerInt16WithCifgWithPeepholeNoProjectionTest(
2692     armnn::IWorkloadFactory& workloadFactory,
2693     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2694     const armnn::ITensorHandleFactory& tensorHandleFactory)
2695 {
2696     const float qScale = 1.0f;
2697     const int32_t qOffset = 0;
2698 
2699     const armnn::DataType datatype = armnn::DataType::QSymmS16;
2700     const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
2701 
2702     armnn::TensorInfo inputDesc({ 2, 2 }, datatype);
2703     std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2704 
2705     armnn::TensorInfo outputDesc({ 2, 4 }, datatype);
2706     std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2707         {
2708             -0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2709             -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f
2710         },
2711         qScale, qOffset);
2712 
2713     return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<datatype>(
2714         workloadFactory, memoryManager, tensorHandleFactory,
2715         input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2716         qScale, qOffset, constantDatatype);
2717 }
2718 
2719 LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgWithPeepholeWithProjectionTest(
2720     armnn::IWorkloadFactory& workloadFactory,
2721     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2722     const armnn::ITensorHandleFactory& tensorHandleFactory)
2723 {
2724     const float qScale = 2.0f;
2725     const int32_t qOffset = 0;
2726 
2727     const armnn::DataType datatype = armnn::DataType::QSymmS16;
2728     const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
2729 
2730     armnn::TensorInfo inputDesc({ 2, 5 }, datatype);
2731     std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>(
2732         {
2733             0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2734             0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f
2735         },
2736         qScale, qOffset);
2737 
2738     armnn::TensorInfo outputDesc({ 2, 16 }, datatype);
2739     std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2740         {
2741             -0.00396806f,  0.02935200f, -0.00279226f,  0.01599770f,
2742             -0.00835576f, -0.02117790f,  0.02835120f, -0.01145970f,
2743              0.00907307f, -0.02440040f, -0.01521910f, -0.02590630f,
2744              0.00914318f,  0.00415118f,  0.01714700f,  0.01342030f,
2745             -0.01386900f,  0.02872680f, -0.00334693f,  0.00733398f,
2746             -0.02879260f, -0.01869260f,  0.01936620f, -0.01154370f,
2747              0.00422612f, -0.03452320f,  0.00223253f, -0.00957321f,
2748              0.02106240f,  0.01333100f,  0.01509540f,  0.02168000f
2749         },
2750         qScale, qOffset);
2751 
2752     return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<datatype>(
2753         workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput, qScale, qOffset, constantDatatype);
2754 }
2755 
2756 LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16ConstantTest(
2757     armnn::IWorkloadFactory& workloadFactory,
2758     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2759     const armnn::ITensorHandleFactory& tensorHandleFactory)
2760 {
2761     const float qScale = 1.0f;
2762     const int32_t qOffset = 0;
2763 
2764     const armnn::DataType datatype = armnn::DataType::QSymmS16; // datatype & constants set to QSymm16
2765 
2766     armnn::TensorInfo inputDesc({2, 2}, datatype);
2767     std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2768 
2769     armnn::TensorInfo outputDesc({2, 4}, datatype);
2770     std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2771         {
2772             -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2773             -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2774         },
2775         qScale, qOffset);
2776 
2777     return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
2778         workloadFactory, memoryManager, tensorHandleFactory,
2779         input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2780         qScale, qOffset, datatype);
2781 }
2782 
2783 //
2784 // QuantizedLstm
2785 //
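// QuantizedLstmTest drives the QuantizedLstm workload with raw QAsymmU8 values and compares the raw quantised
// output directly; the quantisation parameters are set up inside QuantizedLstmTestImpl in the anonymous namespace above.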
2786 
2787 LayerTestResult<uint8_t, 2> QuantizedLstmTest(
2788     armnn::IWorkloadFactory& workloadFactory,
2789     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2790     const armnn::ITensorHandleFactory& tensorHandleFactory)
2791 {
2792     armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::QAsymmU8);
2793     std::vector<uint8_t> input = {166, 179, 50, 150};
2794 
2795     armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmU8);
2796     std::vector<uint8_t> expectedOutput = {140, 151, 146, 112, 136, 156, 142, 112 };
2797 
2798     return QuantizedLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory,
2799                                  input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2800 }
2801 
2802 // QLSTM
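// QLstmTestImpl, QLstmTestImpl1 and QLstmTestImpl2 (in the anonymous namespace above) presumably differ in which
// optional features (CIFG, peephole, projection, layer normalisation) the QLstm descriptor enables.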
2803 LayerTestResult<int8_t, 2> QLstmTest(
2804     armnn::IWorkloadFactory& workloadFactory,
2805     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2806     const armnn::ITensorHandleFactory& tensorHandleFactory)
2807 {
2808     armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
2809     std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2810 
2811     armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmS8);
2812     std::vector<int8_t> expectedOutput = {-15, 21, 14, 20, -15, 15, 5, 27};
2813 
2814     return QLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2815 }
2816 
2817 LayerTestResult<int8_t, 2> QLstmTest1(
2818     armnn::IWorkloadFactory& workloadFactory,
2819     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2820     const armnn::ITensorHandleFactory& tensorHandleFactory)
2821 {
2822     armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
2823     std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2824 
2825     armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8);
2826     std::vector<int8_t> expectedOutput = {127, 127, -108, -67, 127, 127};
2827 
2828     return QLstmTestImpl1(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2829 }
2830 
2831 LayerTestResult<int8_t, 2> QLstmTest2(
2832     armnn::IWorkloadFactory& workloadFactory,
2833     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2834     const armnn::ITensorHandleFactory& tensorHandleFactory)
2835 {
2836     armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
2837     std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2838 
2839     armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8);
2840     std::vector<int8_t> expectedOutput = {127, 127, 127, -128, 127, 127};
2841 
2842     return QLstmTestImpl2(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2843 }