• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/LstmParams.hpp>
10 #include <backendsCommon/Workload.hpp>
11 #include <backendsCommon/WorkloadData.hpp>
12 
13 #include <arm_compute/runtime/CL/functions/CLLSTMLayer.h>
14 
15 namespace armnn
16 {
17 
18 class ClLstmFloatWorkload : public FloatWorkload<LstmQueueDescriptor>
19 {
20 public:
21     ClLstmFloatWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
22     void Execute() const override;
23 
24 private:
25     mutable arm_compute::CLLSTMLayer m_LstmLayer;
26 
27     std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
28     std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
29     std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
30     std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
31     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
32     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
33     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
34     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
35     std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor;
36     std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor;
37     std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor;
38     std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
39     std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
40     std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
41     std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
42     std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor;
43     std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor;
44     std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor;
45     std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor;
46     std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor;
47     std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor;
48 
49     std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer;
50 
51     void FreeUnusedTensors();
52 };
53 
54 arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn,
55                                                 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
56                                                 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
57                                                 const TensorInfo& output, const LstmDescriptor &descriptor,
58                                                 const LstmInputParamsInfo& paramsInfo);
59 } //namespace armnn
60