//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NeonUnidirectionalSequenceLstmWorkload.hpp"
#include "NeonWorkloadUtils.hpp"

#include <aclCommon/ArmComputeUtils.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <armnn/utility/NumericCast.hpp>
#include <armnnUtils/Permute.hpp>
#include <neon/test/NeonWorkloadFactoryHelper.hpp>
#include <backendsCommon/WorkloadUtils.hpp>

#include "neon/NeonTensorHandle.hpp"

namespace
{

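// Converts an Arm NN axis index to the equivalent ACL axis index. ACL numbers tensor
// dimensions from the fastest-varying dimension upwards, the reverse of Arm NN, so for a
// 3-D tensor: CalcAclAxis(3, 0) == 2, CalcAclAxis(3, 1) == 1 and CalcAclAxis(3, 2) == 0.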
unsigned int CalcAclAxis(unsigned int numDimensions, unsigned int axis)
{
    return (numDimensions - axis) - 1;
}
} // namespace

namespace armnn
{
using namespace armcomputetensorutils;

NeonUnidirectionalSequenceLstmWorkload::NeonUnidirectionalSequenceLstmWorkload
    (const UnidirectionalSequenceLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
    : NeonBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
{
    // Report Profiling Details
    ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonUnidirectionalSequenceLstmWorkload_Construct",
                                         descriptor.m_Parameters,
                                         info,
                                         GetGuid());

    // Input/Output tensors
    const arm_compute::ITensor& input         = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
    arm_compute::ITensor& outputStateIn       = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    const arm_compute::ITensor& cellStateIn   = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();

    arm_compute::ITensor& outputStateOut = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
    arm_compute::ITensor& cellStateOut   = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
    arm_compute::ITensor& output         = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();

    TensorInfo inputInfo = info.m_InputTensorInfos[0];
    TensorInfo outputInfo = info.m_OutputTensorInfos[2];

    TensorShape inputLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetShape();
    TensorShape outputLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetShape();

    unsigned int maxTime = m_Data.m_Parameters.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
    unsigned int batchSize = m_Data.m_Parameters.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
    unsigned int inputSize = inputLayerShape[2];
    unsigned int outputSize = outputLayerShape[2];

    const TensorShape timeMajorShapeInput({maxTime, batchSize, inputSize});
    const TensorShape timeMajorShapeOutput({maxTime, batchSize, outputSize});

    //
    // Permute: performed if Unidirectional Sequence Layer inputs/outputs are in batch major format.
    //
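    // Note: ACL orders dimensions from fastest-varying upwards, the reverse of Arm NN, so a
    // batch major Arm NN shape [batch, time, features] corresponds to ACL dimensions
    // (features, time, batch). PermutationVector(0U, 2U, 1U) swaps ACL dimensions 1 and 2,
    // producing the time major layout [time, batch, features] expected by the split below.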
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        std::unique_ptr<arm_compute::NEPermute> layer(new arm_compute::NEPermute());

        TensorInfo permuteOutInfo = inputInfo;
        permuteOutInfo.SetShape(timeMajorShapeInput);
        BuildArmComputeTensor(m_PermuteFirstOut, permuteOutInfo);
        armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermuteFirstOut);

        // Permute to time major format.
        layer->configure(&input, &m_PermuteFirstOut, arm_compute::PermutationVector(0U, 2U, 1U));
        m_Permute1.reset(layer.release());
    }

    //
    // Split and Concat Tensors
    //
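    // One 2-D working tensor of shape [batchSize, inputSize] is built per time step for the
    // splitter outputs, and one of shape [batchSize, outputSize] for the concat inputs; the
    // per-step QLSTM layers configured further below read from and write to these.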
    for (unsigned int i = 0; i < maxTime; ++i)
    {
        arm_compute::Tensor splitter_out;
        arm_compute::Tensor concat_in;

        auto splitterTensorInfo = inputInfo;
        auto concatTensorInfo = outputInfo;
        splitterTensorInfo.SetShape({batchSize, inputSize});
        concatTensorInfo.SetShape({batchSize, outputSize});
        BuildArmComputeTensor(splitter_out, splitterTensorInfo);
        BuildArmComputeTensor(concat_in, concatTensorInfo);

        armcomputetensorutils::InitialiseArmComputeTensorEmpty(splitter_out);
        armcomputetensorutils::InitialiseArmComputeTensorEmpty(concat_in);

        // append to std::vector<arm_compute::Tensor>
        m_SplitterOutputsTensors.push_back(std::move(splitter_out));
        m_ConcatInputsTensors.push_back(std::move(concat_in));
    }

    for (unsigned int i = 0; i < maxTime; ++i)
    {
        // append to std::vector<arm_compute::ITensor*>
        m_SplitterOutputs.push_back(&m_SplitterOutputsTensors[i]);
        m_ConcatInputs.push_back(&m_ConcatInputsTensors[i]);
    }

    //
    // Split
    //
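    // The splitter views the time major tensor [maxTime, batchSize, inputSize] as maxTime
    // slices of shape [1, batchSize, inputSize], with each view's origin advanced along
    // axis 0. CalcAclAxis translates that Arm NN split axis into ACL's reversed numbering.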
    unsigned int numberDimensions = 3;
    unsigned int dimension = 0; // splitting on 0-dimension (i.e. maxTime dimension)

    if (maxTime != 1) // ACL split does not work with only one element to split.
    {
        ViewsDescriptor splitterDesc(maxTime, numberDimensions);
        unsigned int splitterDimSizes[3] = {1, batchSize, inputSize};
        for (unsigned int outputIdx = 0u; outputIdx < maxTime; ++outputIdx)
        {
            splitterDesc.SetViewOriginCoord(outputIdx, dimension, splitterDimSizes[dimension] * outputIdx);
            for (unsigned int dimIdx = 0u; dimIdx < numberDimensions; ++dimIdx)
            {
                splitterDesc.SetViewSize(outputIdx, dimIdx, splitterDimSizes[dimIdx]);
            }
        }

        std::set<unsigned int> splitAxis = ComputeSplitAxis(splitterDesc, timeMajorShapeInput);

        std::unique_ptr<arm_compute::NESplit> split_layer(new arm_compute::NESplit());
        unsigned int aclAxisSplit = CalcAclAxis(splitterDesc.GetNumDimensions(), *splitAxis.begin());
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            split_layer->configure(&m_PermuteFirstOut, m_SplitterOutputs, aclAxisSplit);
        }
        else
        {
            split_layer->configure(&input, m_SplitterOutputs, aclAxisSplit);
        }

        split_layer->prepare();
        m_Splitter.reset(split_layer.release());
    }

    //
    // Lstm
    //
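    // NEQLSTMLayer computes a single time step, so one layer instance is configured per step
    // in the loop further below; all instances share the same weight, bias and state tensors.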
    arm_compute::LSTMParams<arm_compute::ITensor> lstm_param;

    lstm_param.set_cell_clip_params(descriptor.m_Parameters.m_ClippingThresCell);
    lstm_param.set_projection_clip_params(descriptor.m_Parameters.m_ClippingThresProj);

    lstm_param.set_matmul_scale_params(descriptor.m_Parameters.m_InputIntermediateScale,
                                       descriptor.m_Parameters.m_ForgetIntermediateScale,
                                       descriptor.m_Parameters.m_CellIntermediateScale,
                                       descriptor.m_Parameters.m_OutputIntermediateScale);

    lstm_param.set_hidden_state_params(descriptor.m_Parameters.m_HiddenStateZeroPoint,
                                       descriptor.m_Parameters.m_HiddenStateScale);

    m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());

    m_InputToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());

    m_InputToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());

    m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());

    m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());

    m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());

    m_ForgetGateBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());

    m_CellBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());

    m_OutputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());

    // for future reference: check the AndroidNN API for the logic here
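    // CIFG (Coupled Input and Forget Gate) couples the input gate to the forget gate
    // (input gate = 1 - forget gate), so the dedicated input gate parameters below are
    // only required, and only built, when CIFG is disabled.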
    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        m_InputToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());

        m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());

        m_CellToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        if (m_Data.m_CellToInputWeights != nullptr)
        {
            BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
        }

        m_InputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
        lstm_param.set_cifg_params(m_InputToInputWeightsTensor.get(),
                                   m_RecurrentToInputWeightsTensor.get(),
                                   m_Data.m_CellToInputWeights ? m_CellToInputWeightsTensor.get() : nullptr,
                                   m_InputGateBiasTensor.get());
    }

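    // Projection applies an extra weight matrix (and an optional bias) to the LSTM output
    // before it becomes the output state; the bias tensor is only built when one is supplied.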
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        m_ProjectionWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());

        m_ProjectionBiasTensor = std::make_unique<arm_compute::Tensor>();
        if (m_Data.m_ProjectionBias != nullptr)
        {
            BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
        }

        lstm_param.set_projection_params(m_ProjectionWeightsTensor.get(),
                                         m_Data.m_ProjectionBias ? m_ProjectionBiasTensor.get() : nullptr);
    }

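    // Peephole connections let the forget and output gates read the cell state directly
    // through per-cell weight vectors (the cell-to-input weights are handled in the CIFG
    // parameters above).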
    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        m_CellToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());

        m_CellToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());

        lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
    }

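    // Layer normalization uses a learned weight vector per gate; the input gate's
    // normalization weights only exist when CIFG is disabled, hence the nullptr otherwise.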
    if (m_Data.m_Parameters.m_LayerNormEnabled)
    {
        m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
        }

        m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());

        m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());

        m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());

        auto inputNormWeightTensor = m_Data.m_Parameters.m_CifgEnabled ? nullptr : m_InputLayerNormWeightsTensor.get();
        lstm_param.set_layer_normalization_params(inputNormWeightTensor,
                                                  m_ForgetLayerNormWeightsTensor.get(),
                                                  m_CellLayerNormWeightsTensor.get(),
                                                  m_OutputLayerNormWeightsTensor.get());
    }

    for (unsigned int i = 0; i != maxTime; ++i)
    {
        // Set LSTM input and output ITensors depending on:
        // input format (timeMajor) & number of LSTM batches (maxTime).
        arm_compute::ITensor* outputLSTM;
        arm_compute::ITensor* inputLSTM;

        // If there is only one LSTM time major batch, we will not concat OR permute.
        // Set input of LSTM to be first input ITensor.
        // Set output of LSTM to be final output ITensor.
        // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
        if (maxTime == 1 && m_Data.m_Parameters.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(input.info()->tensor_shape(), 1U);
            TensorShape outputShape = GetTensorShape(output.info()->tensor_shape(), 1U);

            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            TensorShape outputShapeShrink({outputShape[1], outputShape[2]});

            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            auto acl_output_shape_shrink = BuildArmComputeTensorShape(outputShapeShrink);

            input.info()->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = const_cast<arm_compute::ITensor*>(&input);

            output.info()->set_tensor_shape(acl_output_shape_shrink);
            outputLSTM = &output;
        }
        // If there is only one LSTM batch major batch, we will not concat, only permute.
        // Set input of LSTM to be output of initial permute.
        // Set output of LSTM to be first element of m_ConcatInputs & use that value later in permute.
        // LSTM output cannot be > 2 dimensions so need to resize its TensorInfo.
        else if (maxTime == 1 && !m_Data.m_Parameters.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(m_PermuteFirstOut.info()->tensor_shape(), 1U);
            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            m_PermuteFirstOut.info()->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = &m_PermuteFirstOut;

            outputLSTM = const_cast<arm_compute::ITensor*>(m_ConcatInputs[i]);
        }
        // Batch major AND/OR 2+ LSTM batches so will use concat AND/OR permute later on.
        else
        {
            inputLSTM = m_SplitterOutputs[i];
            outputLSTM = const_cast<arm_compute::ITensor*>(m_ConcatInputs[i]);
        }

        std::unique_ptr<arm_compute::NEQLSTMLayer> lstm_layer(new arm_compute::NEQLSTMLayer());

        lstm_layer->configure(inputLSTM,
                              m_InputToForgetWeightsTensor.get(),
                              m_InputToCellWeightsTensor.get(),
                              m_InputToOutputWeightsTensor.get(),
                              m_RecurrentToForgetWeightsTensor.get(),
                              m_RecurrentToCellWeightsTensor.get(),
                              m_RecurrentToOutputWeightsTensor.get(),
                              m_ForgetGateBiasTensor.get(),
                              m_CellBiasTensor.get(),
                              m_OutputGateBiasTensor.get(),
                              &cellStateIn,
                              &outputStateIn,
                              &cellStateOut,
                              &outputStateOut,
                              outputLSTM,
                              lstm_param);

        m_Layers.emplace_back(std::move(lstm_layer));
    }

    InitializeArmComputeTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
    InitializeArmComputeTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
    InitializeArmComputeTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
    InitializeArmComputeTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
    InitializeArmComputeTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
    InitializeArmComputeTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
    InitializeArmComputeTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
    InitializeArmComputeTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
    InitializeArmComputeTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);

    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        InitializeArmComputeTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
        InitializeArmComputeTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
        if (m_Data.m_CellToInputWeights != nullptr)
        {
            InitializeArmComputeTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
        }
        InitializeArmComputeTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        InitializeArmComputeTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
        if (m_Data.m_ProjectionBias != nullptr)
        {
            InitializeArmComputeTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
        }
    }

    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        InitializeArmComputeTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
        InitializeArmComputeTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
    }

    if (m_Data.m_Parameters.m_LayerNormEnabled)
    {
        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
        }
        InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
        InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
        InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
    }

    // Force Compute Library to perform the necessary copying and reshaping.
    // After this, the input tensors that are no longer needed can be freed.
    for (uint32_t i = 0; i < m_Layers.size(); ++i)
    {
        m_Layers[i]->prepare();
    }

    //
    // Concat
    //
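    // Each per-step LSTM output is a 2-D tensor [batchSize, outputSize]; reinterpreting it
    // as [1, batchSize, outputSize] gives every concat input a singleton time axis so that
    // NEConcatenateLayer can stitch the steps together along the time dimension.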

    // Expand dimensions of LSTM outputs adding one empty dimension to fit concatenate inputs.
    TensorShape shape = GetTensorShape(m_ConcatInputs[0]->info()->tensor_shape(), 1U);
    TensorShape shapeExpandTimeMajor({1, shape[0], shape[1]});
    TensorShape shapeExpandBatchMajor({shape[0], 1, shape[1]});

    if (maxTime != 1) // ACL concat does not work with only one element to concatenate.
    {
        for (unsigned int i = 0; i < maxTime; ++i)
        {
            m_ConcatInputs[i]->info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }
        ConcatDescriptor concatDescriptor(maxTime, numberDimensions); // maxTime = num inputs (aka. number of views).

        for (unsigned int inputIdx = 0u; inputIdx < maxTime; ++inputIdx)
        {
            concatDescriptor.SetViewOriginCoord(inputIdx, dimension, inputIdx);
            concatDescriptor.SetConcatAxis(dimension);
        }
        m_Concat.reset(new arm_compute::NEConcatenateLayer());

        unsigned int aclAxisConcat = CalcAclAxis(concatDescriptor.GetNumDimensions(), concatDescriptor.GetConcatAxis());
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            TensorInfo concatOutputTensorInfo = outputInfo;
            concatOutputTensorInfo.SetShape(timeMajorShapeOutput);
            BuildArmComputeTensor(concat_out, concatOutputTensorInfo);
            armcomputetensorutils::InitialiseArmComputeTensorEmpty(concat_out);

            m_Concat->configure(m_ConcatInputs, &concat_out, aclAxisConcat);
        }
        else
        {
            m_Concat->configure(m_ConcatInputs, &output, aclAxisConcat);
        }

        m_Concat->prepare();
    }
    // If only one LSTM batch, we do not concat and/or permute.
    // Must ensure final output info is expanded to correct batch major dimensions.
    else
    {
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            output.info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandBatchMajor));
        }
        else
        {
            output.info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }
    }

    //
    // Permute: only done if input/output are in batch major format.
    //
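    // PermutationVector(0U, 2U, 1U) is a self-inverse swap of ACL dimensions 1 and 2, so
    // applying it again converts the time major result back to batch major.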
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Output now time major. Permute output back to batch major.
        std::unique_ptr<arm_compute::NEPermute> layer(new arm_compute::NEPermute());
        if (maxTime != 1)
        {
            layer->configure(&concat_out, &output, arm_compute::PermutationVector(0U, 2U, 1U));
        }
        else
        {
            layer->configure(m_ConcatInputs[0], &output, arm_compute::PermutationVector(0U, 2U, 1U));
        }
        m_Permute2.reset(layer.release());
    }

    FreeUnusedTensors();
}

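// Runs the construction-time pipeline in order: optional permute to time major, split into
// time steps, one QLSTM per step, concat of the step outputs, optional permute back to
// batch major. Stages that were not configured are skipped.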
void NeonUnidirectionalSequenceLstmWorkload::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonUnidirectionalSequenceLstmWorkload_Execute", GetGuid());
    if (m_Permute1)
    {
        m_Permute1->run();
    }
    if (m_Splitter)
    {
        m_Splitter->run();
    }
    for (uint32_t i = 0; i < m_Layers.size(); ++i)
    {
        m_Layers[i]->run();
    }
    if (m_Concat)
    {
        m_Concat->run();
    }
    if (m_Permute2)
    {
        m_Permute2->run();
    }
}

arm_compute::Status
NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input,
                                               const TensorInfo& outputStateIn,
                                               const TensorInfo& cellStateIn,
                                               const TensorInfo& outputStateOut,
                                               const TensorInfo& cellStateOut,
                                               const TensorInfo& output,
                                               const UnidirectionalSequenceLstmDescriptor& descriptor,
                                               const LstmInputParamsInfo& paramsInfo)
{
    TensorShape inputLayerShape = input.GetShape();
    TensorShape outputLayerShape = output.GetShape();

    unsigned int maxTime = descriptor.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
    unsigned int batchSize = descriptor.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
    unsigned int inputSize = inputLayerShape[2];
    unsigned int outputSize = outputLayerShape[2];

    const TensorShape timeMajorShapeInput({maxTime, batchSize, inputSize});
    const TensorShape timeMajorShapeOutput({maxTime, batchSize, outputSize});

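    // Validate each stage of the pipeline the workload constructor would build (permute,
    // split, per-step QLSTM, concat, permute back); the final status is OK only when every
    // stage reports OK.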
    arm_compute::Status statusPermute1 = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "Permute1 status");
    arm_compute::Status statusSplit = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                          "Split status");
    arm_compute::Status statusLSTM = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                         "LSTM status");
    arm_compute::Status statusConcat = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                           "Concat status");
    arm_compute::Status statusPermute2 = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "Permute2 status");

    const arm_compute::TensorInfo aclInputInfo  = armcomputetensorutils::BuildArmComputeTensorInfo(input);
    const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);

    //
    // Permute validate
    //
    TensorInfo permuteOutInfo = TensorInfo(input);
    arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo);
    if (!descriptor.m_TimeMajor)
    {
        statusPermute1 = arm_compute::NEPermute::validate(&aclInputInfo,
                                                          &aclPermuteOutInfo,
                                                          arm_compute::PermutationVector(0U, 2U, 1U));
    }

    //
    // Split and Concat Tensors validate
    //
    std::vector<arm_compute::TensorInfo> splitterOutputsTensorInfos;
    std::vector<arm_compute::TensorInfo> concatInputsTensorInfos;
    std::vector<arm_compute::ITensorInfo*> splitterOutputsTensorInfosPtr;
    std::vector<const arm_compute::ITensorInfo*> concatInputsTensorInfosPtr;
    splitterOutputsTensorInfos.reserve(maxTime);
    concatInputsTensorInfos.reserve(maxTime);
    for (unsigned int i = 0; i < maxTime; ++i)
    {
        arm_compute::TensorInfo splitter_out;
        arm_compute::TensorInfo concat_in;

        auto splitterTensorInfo = TensorInfo(input);
        auto concatTensorInfo   = TensorInfo(output);
        splitterTensorInfo.SetShape({batchSize, inputSize});
        concatTensorInfo.SetShape({batchSize, outputSize});

        arm_compute::TensorInfo aclSplitterTensorInfo
            = armcomputetensorutils::BuildArmComputeTensorInfo(splitterTensorInfo);
        arm_compute::TensorInfo aclConcatTensorInfo
            = armcomputetensorutils::BuildArmComputeTensorInfo(concatTensorInfo);

        splitterOutputsTensorInfos.emplace_back(aclSplitterTensorInfo);
        concatInputsTensorInfos.emplace_back(aclConcatTensorInfo);
        splitterOutputsTensorInfosPtr.emplace_back(&splitterOutputsTensorInfos[i]);
        concatInputsTensorInfosPtr.emplace_back(&concatInputsTensorInfos[i]);
    }

    //
    // Split validate
    //
    unsigned int numberDimensions = 3;
    unsigned int dimension = 0; // splitting on 0-dimension (i.e. maxTime dimension)
    unsigned int aclAxisSplit = CalcAclAxis(numberDimensions, dimension);

    if (maxTime != 1) // ACL split does not work with only one element to split.
    {
        if (!descriptor.m_TimeMajor)
        {
            statusSplit = arm_compute::NESplit::validate(&aclPermuteOutInfo,
                                                         splitterOutputsTensorInfosPtr,
                                                         aclAxisSplit);
        }
        else
        {
            statusSplit = arm_compute::NESplit::validate(&aclInputInfo, splitterOutputsTensorInfosPtr, aclAxisSplit);
        }
    }

    //
    // LSTM validate
    //

    arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;

    const TensorInfo scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType());

    lstm_params_info.set_cell_clip_params(descriptor.m_ClippingThresCell);
    lstm_params_info.set_projection_clip_params(descriptor.m_ClippingThresProj);
    // The inputs and outputs
    const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
    const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
    const arm_compute::TensorInfo aclScratchBufferInfo = BuildArmComputeTensorInfo(scratchBuffer);
    const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
    const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);

    // Basic parameters
    const arm_compute::TensorInfo aclInputToForgetWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
    const arm_compute::TensorInfo aclInputToCellWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
    const arm_compute::TensorInfo aclInputToOutputWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
    const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
    const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
    const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
    const arm_compute::TensorInfo aclForgetGateBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
    const arm_compute::TensorInfo aclCellBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
    const arm_compute::TensorInfo aclOutputGateBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());

    arm_compute::TensorInfo aclInputToInputWeightsInfo;
    arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
    arm_compute::TensorInfo aclCellToInputWeightsInfo;
    arm_compute::TensorInfo aclInputGateBiasInfo;
    arm_compute::TensorInfo aclProjectionWeightsInfo;
    arm_compute::TensorInfo aclProjectionBiasInfo;
    arm_compute::TensorInfo aclCellToForgetWeightsInfo;
    arm_compute::TensorInfo aclCellToOutputWeightsInfo;

    arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
    arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
    arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
    arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;

    if (!descriptor.m_CifgEnabled)
    {
        if (descriptor.m_PeepholeEnabled)
        {
            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
        }
        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());

        lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo,
                                         &aclRecurrentToInputWeightsInfo,
                                         descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr,
                                         &aclInputGateBiasInfo);
    }

    if (descriptor.m_ProjectionEnabled)
    {
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
        }
        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());

        lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
                                               paramsInfo.m_ProjectionBias ? &aclProjectionBiasInfo : nullptr);
    }

    if (descriptor.m_PeepholeEnabled)
    {
        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());

        lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
    }

    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
        }
        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
        aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());

        lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ? nullptr :
                                                        &aclInputLayerNormWeightsInfo,
                                                        &aclForgetLayerNormWeightsInfo,
                                                        &aclCellLayerNormWeightsInfo,
                                                        &aclOutputLayerNormWeightsInfo);
    }

    lstm_params_info.set_matmul_scale_params(descriptor.m_InputIntermediateScale,
                                             descriptor.m_ForgetIntermediateScale,
                                             descriptor.m_CellIntermediateScale,
                                             descriptor.m_OutputIntermediateScale);

    lstm_params_info.set_hidden_state_params(descriptor.m_HiddenStateZeroPoint, descriptor.m_HiddenStateScale);

    for (unsigned int i = 0; i != maxTime; ++i)
    {
        // Set LSTM input and output ITensors depending on:
        // input format (timeMajor) & number of LSTM batches (maxTime).
        arm_compute::ITensorInfo* outputLSTM;
        arm_compute::ITensorInfo* inputLSTM;

        // If there is only one LSTM time major batch, we will not concat OR permute.
        // Set input of LSTM to be first input ITensor.
        // Set output of LSTM to be final output ITensor.
        // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
        if (maxTime == 1 && descriptor.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U);
            TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U);

            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            TensorShape outputShapeShrink({outputShape[1], outputShape[2]});

            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            auto acl_output_shape_shrink = BuildArmComputeTensorShape(outputShapeShrink);

            const_cast<arm_compute::TensorInfo*>(&aclInputInfo)->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = const_cast<arm_compute::TensorInfo*>(&aclInputInfo);

            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(acl_output_shape_shrink);
            outputLSTM = const_cast<arm_compute::TensorInfo*>(&aclOutputInfo);
        }
        // If there is only one LSTM batch major batch, we will not concat, only permute.
        // Set input of LSTM to be output of initial permute.
        // Set output of LSTM to be first element of m_ConcatInputs & use that value later in permute.
        // LSTM output cannot be > 2 dimensions so need to resize its TensorInfo.
        else if (maxTime == 1 && !descriptor.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(aclPermuteOutInfo.tensor_shape(), 1U);
            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            aclPermuteOutInfo.set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = &aclPermuteOutInfo;

            outputLSTM = const_cast<arm_compute::ITensorInfo*>(concatInputsTensorInfosPtr[i]);
        }
        // Batch major AND/OR 2+ LSTM batches so will use concat AND/OR permute later on.
        else
        {
            inputLSTM = splitterOutputsTensorInfosPtr[i];
            outputLSTM = const_cast<arm_compute::ITensorInfo*>(concatInputsTensorInfosPtr[i]);
        }

        statusLSTM = arm_compute::NEQLSTMLayer::validate(inputLSTM,
                                                         &aclInputToForgetWeightsInfo,
                                                         &aclInputToCellWeightsInfo,
                                                         &aclInputToOutputWeightsInfo,
                                                         &aclRecurrentToForgetWeightsInfo,
                                                         &aclRecurrentToCellWeightsInfo,
                                                         &aclRecurrentToOutputWeightsInfo,
                                                         &aclForgetGateBiasInfo,
                                                         &aclCellBiasInfo,
                                                         &aclOutputGateBiasInfo,
                                                         &aclCellStateInInfo,
                                                         &aclOutputStateInInfo,
                                                         &aclCellStateOutInfo,
                                                         &aclOutputStateOutInfo,
                                                         outputLSTM,
                                                         lstm_params_info);
    }

    //
    // Concat validate
    //

    // Expand dimensions of LSTM outputs adding one empty dimension to fit concatenate inputs.
    TensorShape shape = GetTensorShape(concatInputsTensorInfosPtr[0]->tensor_shape(), 1U);
    TensorShape shapeExpandTimeMajor({1, shape[0], shape[1]});
    TensorShape shapeExpandBatchMajor({shape[0], 1, shape[1]});

    TensorInfo concatOutputTensorInfo = TensorInfo(output);
    concatOutputTensorInfo.SetShape(timeMajorShapeOutput);
    arm_compute::TensorInfo aclConcatOutputTensorInfo = BuildArmComputeTensorInfo(concatOutputTensorInfo);

    if (maxTime != 1) // ACL concat does not work with only one element to concatenate.
    {
        for (unsigned int i = 0; i < maxTime; ++i)
        {
            auto acl_shape_expand = BuildArmComputeTensorShape(shapeExpandTimeMajor);
            concatInputsTensorInfos[i].set_tensor_shape(acl_shape_expand);
        }

        unsigned int aclAxisConcat = CalcAclAxis(numberDimensions, dimension);
        if (!descriptor.m_TimeMajor)
        {
            statusConcat = arm_compute::NEConcatenateLayer::validate(concatInputsTensorInfosPtr,
                                                                     &aclConcatOutputTensorInfo,
                                                                     aclAxisConcat);
        }
        else
        {
            statusConcat = arm_compute::NEConcatenateLayer::validate(concatInputsTensorInfosPtr,
                                                                     &aclOutputInfo,
                                                                     aclAxisConcat);
        }
    }
    // If only one LSTM batch, we do not concat and/or permute.
    // Must ensure final output info is expanded to correct batch major dimensions.
    else
    {
        if (!descriptor.m_TimeMajor)
        {
            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(
                BuildArmComputeTensorShape(shapeExpandBatchMajor));
        }
        else
        {
            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(
                BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }
    }

    //
    // Permute validate
    //
    if (!descriptor.m_TimeMajor)
    {
        // Output now time major. Permute output back to batch major.
        if (maxTime != 1)
        {
            statusPermute2 = arm_compute::NEPermute::validate(&aclConcatOutputTensorInfo,
                                                              &aclOutputInfo,
                                                              arm_compute::PermutationVector(0U, 2U, 1U));
        }
        else
        {
            statusPermute2 = arm_compute::NEPermute::validate(concatInputsTensorInfosPtr[0],
                                                              &aclOutputInfo,
                                                              arm_compute::PermutationVector(0U, 2U, 1U));
        }
    }

    auto okCode = arm_compute::ErrorCode::OK;
    if (statusPermute1.error_code() == okCode &&
        statusSplit.error_code()    == okCode &&
        statusLSTM.error_code()     == okCode &&
        statusConcat.error_code()   == okCode &&
        statusPermute2.error_code() == okCode)
    {
        return arm_compute::Status(arm_compute::ErrorCode::OK,
                                   "All Unidirectional Sequence LSTM layer validate status OK.");
    }
    else
    {
        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
                                   "Unidirectional Sequence LSTM layer validate status failed.");
    }
}

void NeonUnidirectionalSequenceLstmWorkload::FreeUnusedTensors()
{
    FreeTensorIfUnused(m_InputToInputWeightsTensor);
    FreeTensorIfUnused(m_InputToForgetWeightsTensor);
    FreeTensorIfUnused(m_InputToCellWeightsTensor);
    FreeTensorIfUnused(m_InputToOutputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
    FreeTensorIfUnused(m_CellToInputWeightsTensor);
    FreeTensorIfUnused(m_CellToForgetWeightsTensor);
    FreeTensorIfUnused(m_CellToOutputWeightsTensor);
    FreeTensorIfUnused(m_InputGateBiasTensor);
    FreeTensorIfUnused(m_ForgetGateBiasTensor);
    FreeTensorIfUnused(m_CellBiasTensor);
    FreeTensorIfUnused(m_OutputGateBiasTensor);
    FreeTensorIfUnused(m_ProjectionWeightsTensor);
    FreeTensorIfUnused(m_ProjectionBiasTensor);
    FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
    FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
    FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
    FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
}

} // namespace armnn