• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "NeonQLstmWorkload.hpp"
7 #include "NeonWorkloadUtils.hpp"
8 
9 #include "aclCommon/ArmComputeTensorUtils.hpp"
10 
11 #include "neon/NeonTensorHandle.hpp"
12 
13 namespace armnn
14 {
15 using namespace armcomputetensorutils;
16 
NeonQLstmWorkload(const QLstmQueueDescriptor & descriptor,const WorkloadInfo & info)17 NeonQLstmWorkload::NeonQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
18         : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
19 {
20     arm_compute::LSTMParams<arm_compute::ITensor> qLstmParams;
21 
22     // Mandatory params
23     m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
24     BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
25 
26     m_InputToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
27     BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
28 
29     m_InputToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
30     BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
31 
32     m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
33     BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
34 
35     m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
36     BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
37 
38     m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
39     BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
40 
41     m_ForgetGateBiasTensor = std::make_unique<arm_compute::Tensor>();
42     BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
43 
44     m_CellBiasTensor = std::make_unique<arm_compute::Tensor>();
45     BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
46 
47     m_OutputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
48     BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
49 
50     // Create tensors for optional params if they are enabled
51     if (m_Data.m_Parameters.m_PeepholeEnabled)
52     {
53         m_CellToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
54 
55         if (!m_Data.m_Parameters.m_CifgEnabled)
56         {
57             // In ACL this is categorised as a CIFG param and not a Peephole param
58             BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
59         }
60 
61         m_CellToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
62         BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());
63 
64         m_CellToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
65         BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());
66 
67         // Set Peephole params
68         qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(),
69                                         m_CellToOutputWeightsTensor.get());
70     }
71 
72     if (m_Data.m_Parameters.m_ProjectionEnabled)
73     {
74         m_ProjectionWeightsTensor = std::make_unique<arm_compute::Tensor>();
75         BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());
76 
77         m_ProjectionBiasTensor = std::make_unique<arm_compute::Tensor>();
78         if (m_Data.m_ProjectionBias != nullptr)
79         {
80             BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
81         }
82 
83         // Set projection params
84         qLstmParams.set_projection_params(
85             m_ProjectionWeightsTensor.get(),
86             m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
87     }
88 
89     if (m_Data.m_Parameters.m_LayerNormEnabled)
90     {
91         m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
92 
93         if (!m_Data.m_Parameters.m_CifgEnabled)
94         {
95             BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
96         }
97 
98         m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
99         BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
100 
101         m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
102         BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
103 
104         m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
105         BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
106 
107         // Set layer norm params
108         qLstmParams.set_layer_normalization_params(
109             m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
110             m_ForgetLayerNormWeightsTensor.get(),
111             m_CellLayerNormWeightsTensor.get(),
112             m_OutputLayerNormWeightsTensor.get());
113     }
114 
115     if (!m_Data.m_Parameters.m_CifgEnabled)
116     {
117         m_InputToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
118         BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
119 
120         m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
121         BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
122 
123         m_InputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
124         BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
125 
126         // Set CIFG params
127         qLstmParams.set_cifg_params(
128             m_InputToInputWeightsTensor.get(),
129             m_RecurrentToInputWeightsTensor.get(),
130             m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
131             m_InputGateBiasTensor.get());
132     }
133 
134     // Input/Output tensors
135     const arm_compute::ITensor& input         = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
136     arm_compute::ITensor& outputStateIn = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
137     const arm_compute::ITensor& cellStateIn   = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
138 
139     arm_compute::ITensor& outputStateOut = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
140     arm_compute::ITensor& cellStateOut   = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
141     arm_compute::ITensor& output         = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
142 
143     // Set scalar descriptor params
144     qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip);
145     qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip);
146     qLstmParams.set_hidden_state_params(m_Data.m_Parameters.m_HiddenStateZeroPoint,
147                                         m_Data.m_Parameters.m_HiddenStateScale);
148     qLstmParams.set_matmul_scale_params(m_Data.m_Parameters.m_InputIntermediateScale,
149                                         m_Data.m_Parameters.m_ForgetIntermediateScale,
150                                         m_Data.m_Parameters.m_CellIntermediateScale,
151                                         m_Data.m_Parameters.m_OutputIntermediateScale);
152 
153     // QLSTM NEON configure
154     m_QLstmLayer.configure(&input,
155                            m_InputToForgetWeightsTensor.get(),
156                            m_InputToCellWeightsTensor.get(),
157                            m_InputToOutputWeightsTensor.get(),
158                            m_RecurrentToForgetWeightsTensor.get(),
159                            m_RecurrentToCellWeightsTensor.get(),
160                            m_RecurrentToOutputWeightsTensor.get(),
161                            m_ForgetGateBiasTensor.get(),
162                            m_CellBiasTensor.get(),
163                            m_OutputGateBiasTensor.get(),
164                            &cellStateIn,
165                            &outputStateIn,
166                            &cellStateOut,
167                            &outputStateOut,
168                            &output,
169                            qLstmParams);
170 
171     // Initialise ACL tensor data for mandatory params
172     InitializeArmComputeTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
173     InitializeArmComputeTensorData(*m_InputToCellWeightsTensor,   m_Data.m_InputToCellWeights);
174     InitializeArmComputeTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
175 
176     InitializeArmComputeTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
177     InitializeArmComputeTensorData(*m_RecurrentToCellWeightsTensor,   m_Data.m_RecurrentToCellWeights);
178     InitializeArmComputeTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
179 
180     InitializeArmComputeTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
181     InitializeArmComputeTensorData(*m_CellBiasTensor,       m_Data.m_CellBias);
182     InitializeArmComputeTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
183 
184     // Initialise ACL tensor data for optional params
185     if (!m_Data.m_Parameters.m_CifgEnabled)
186     {
187         InitializeArmComputeTensorData(*m_InputToInputWeightsTensor,     m_Data.m_InputToInputWeights);
188         InitializeArmComputeTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
189         InitializeArmComputeTensorData(*m_InputGateBiasTensor,           m_Data.m_InputGateBias);
190     }
191 
192     if (m_Data.m_Parameters.m_ProjectionEnabled)
193     {
194         InitializeArmComputeTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
195 
196         if (m_Data.m_ProjectionBias != nullptr)
197         {
198             InitializeArmComputeTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
199         }
200     }
201 
202     if (m_Data.m_Parameters.m_PeepholeEnabled)
203     {
204         if (!m_Data.m_Parameters.m_CifgEnabled)
205         {
206             InitializeArmComputeTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
207         }
208 
209         InitializeArmComputeTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
210         InitializeArmComputeTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
211     }
212 
213     if (m_Data.m_Parameters.m_LayerNormEnabled)
214     {
215         if (!m_Data.m_Parameters.m_CifgEnabled)
216         {
217             InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
218         }
219 
220         InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
221         InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor,   m_Data.m_CellLayerNormWeights);
222         InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
223     }
224 
225     // QLSTM NEON prepare
226     m_QLstmLayer.prepare();
227 
228     FreeUnusedTensors();
229 }
230 
Execute() const231 void NeonQLstmWorkload::Execute() const
232 {
233     m_QLstmLayer.run();
234 }
235 
NeonQLstmWorkloadValidate(const TensorInfo & input,const TensorInfo & cellStateIn,const TensorInfo & outputStateIn,const TensorInfo & cellStateOut,const TensorInfo & outputStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo)236 arm_compute::Status NeonQLstmWorkloadValidate(const TensorInfo& input,
237                                               const TensorInfo& cellStateIn,
238                                               const TensorInfo& outputStateIn,
239                                               const TensorInfo& cellStateOut,
240                                               const TensorInfo& outputStateOut,
241                                               const TensorInfo& output,
242                                               const QLstmDescriptor& descriptor,
243                                               const LstmInputParamsInfo& paramsInfo)
244 {
245     arm_compute::LSTMParams<arm_compute::ITensorInfo> aclParamsInfo;
246 
247     // Input/Output tensor info
248     const arm_compute::TensorInfo aclInputInfo         = BuildArmComputeTensorInfo(input);
249     const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
250     const arm_compute::TensorInfo aclCellStateInInfo   = BuildArmComputeTensorInfo(cellStateIn);
251 
252     const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
253     const arm_compute::TensorInfo aclCellStateOutInfo   = BuildArmComputeTensorInfo(cellStateOut);
254     const arm_compute::TensorInfo aclOutputInfo         = BuildArmComputeTensorInfo(output);
255 
256     // Mandatory tensor info
257     const arm_compute::TensorInfo aclInputToForgetWeightsInfo
258             = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
259     const arm_compute::TensorInfo aclInputToCellWeightsInfo
260             = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
261     const arm_compute::TensorInfo aclInputToOutputWeightsInfo
262             = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
263     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
264             = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
265     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
266             = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
267     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
268             = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
269     const arm_compute::TensorInfo aclForgetGateBiasInfo
270             = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
271     const arm_compute::TensorInfo aclCellBiasInfo
272             = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
273     const arm_compute::TensorInfo aclOutputGateBiasInfo
274             = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
275 
276     // Optional tensor info
277     arm_compute::TensorInfo aclInputToInputWeightsInfo;
278     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
279 
280     arm_compute::TensorInfo aclCellToInputWeightsInfo;
281     arm_compute::TensorInfo aclCellToForgetWeightsInfo;
282     arm_compute::TensorInfo aclCellToOutputWeightsInfo;
283 
284     arm_compute::TensorInfo aclInputGateBiasInfo;
285 
286     arm_compute::TensorInfo aclProjectionWeightsInfo;
287     arm_compute::TensorInfo aclProjectionBiasInfo;
288 
289     arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
290     arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
291     arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
292     arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
293 
294     // Create tensor info for optional params if they are enabled
295     if (descriptor.m_PeepholeEnabled)
296     {
297         if (!descriptor.m_CifgEnabled)
298         {
299             aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
300         }
301 
302         aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
303         aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());
304 
305         // Set peephole params info
306         aclParamsInfo.set_peephole_params(&aclCellToForgetWeightsInfo,
307                                           &aclCellToOutputWeightsInfo);
308     }
309 
310     if (descriptor.m_ProjectionEnabled)
311     {
312         aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());
313 
314         if (paramsInfo.m_ProjectionBias != nullptr)
315         {
316             aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
317         }
318 
319         // Set projection params info
320         aclParamsInfo.set_projection_params(
321             &aclProjectionWeightsInfo,
322             paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
323     }
324 
325     if (descriptor.m_LayerNormEnabled)
326     {
327         if (!descriptor.m_CifgEnabled)
328         {
329             aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
330         }
331 
332         aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
333         aclCellLayerNormWeightsInfo   = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
334         aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());
335 
336         // Set layer norm params info
337         aclParamsInfo.set_layer_normalization_params(
338             paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr,
339             &aclForgetLayerNormWeightsInfo,
340             &aclCellLayerNormWeightsInfo,
341             &aclOutputLayerNormWeightsInfo);
342     }
343 
344     if (!descriptor.m_CifgEnabled)
345     {
346         aclInputToInputWeightsInfo     = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
347         aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
348         aclInputGateBiasInfo           = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
349 
350         // Set CIFG params info
351         aclParamsInfo.set_cifg_params(
352             &aclInputToInputWeightsInfo,
353             &aclRecurrentToInputWeightsInfo,
354             paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
355             &aclInputGateBiasInfo);
356     }
357 
358     // Set scalar descriptor params
359     aclParamsInfo.set_cell_clip_params(descriptor.m_CellClip);
360     aclParamsInfo.set_projection_clip_params(descriptor.m_ProjectionClip);
361     aclParamsInfo.set_hidden_state_params(descriptor.m_HiddenStateZeroPoint, descriptor.m_HiddenStateScale);
362     aclParamsInfo.set_matmul_scale_params(descriptor.m_InputIntermediateScale,
363                                           descriptor.m_ForgetIntermediateScale,
364                                           descriptor.m_CellIntermediateScale,
365                                           descriptor.m_OutputIntermediateScale);
366 
367     // QLSTM NEON validate
368     return arm_compute::NEQLSTMLayer::validate(&aclInputInfo,
369                                                &aclInputToForgetWeightsInfo,
370                                                &aclInputToCellWeightsInfo,
371                                                &aclInputToOutputWeightsInfo,
372                                                &aclRecurrentToForgetWeightsInfo,
373                                                &aclRecurrentToCellWeightsInfo,
374                                                &aclRecurrentToOutputWeightsInfo,
375                                                &aclForgetGateBiasInfo,
376                                                &aclCellBiasInfo,
377                                                &aclOutputGateBiasInfo,
378                                                &aclCellStateInInfo,
379                                                &aclOutputStateInInfo,
380                                                &aclCellStateOutInfo,
381                                                &aclOutputStateOutInfo,
382                                                &aclOutputInfo,
383                                                aclParamsInfo);
384 }
385 
FreeUnusedTensors()386 void NeonQLstmWorkload::FreeUnusedTensors()
387 {
388     FreeTensorIfUnused(m_InputToInputWeightsTensor);
389     FreeTensorIfUnused(m_InputToForgetWeightsTensor);
390     FreeTensorIfUnused(m_InputToCellWeightsTensor);
391     FreeTensorIfUnused(m_InputToOutputWeightsTensor);
392 
393     FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
394     FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
395     FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
396     FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
397 
398     FreeTensorIfUnused(m_CellToInputWeightsTensor);
399     FreeTensorIfUnused(m_CellToForgetWeightsTensor);
400     FreeTensorIfUnused(m_CellToOutputWeightsTensor);
401 
402     FreeTensorIfUnused(m_InputGateBiasTensor);
403     FreeTensorIfUnused(m_ForgetGateBiasTensor);
404     FreeTensorIfUnused(m_CellBiasTensor);
405     FreeTensorIfUnused(m_OutputGateBiasTensor);
406 
407     FreeTensorIfUnused(m_ProjectionWeightsTensor);
408     FreeTensorIfUnused(m_ProjectionBiasTensor);
409 
410     FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
411     FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
412     FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
413     FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
414 }
415 
416 } //namespace armnn