1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Descriptors.hpp> 9 #include <armnn/LstmParams.hpp> 10 #include <backendsCommon/Workload.hpp> 11 #include <backendsCommon/WorkloadData.hpp> 12 13 #include <arm_compute/runtime/CL/functions/CLLSTMLayer.h> 14 15 namespace armnn 16 { 17 18 class ClLstmFloatWorkload : public FloatWorkload<LstmQueueDescriptor> 19 { 20 public: 21 ClLstmFloatWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info); 22 void Execute() const override; 23 24 private: 25 mutable arm_compute::CLLSTMLayer m_LstmLayer; 26 27 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor; 28 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor; 29 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor; 30 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor; 31 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor; 32 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor; 33 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor; 34 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor; 35 std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor; 36 std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor; 37 std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor; 38 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor; 39 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor; 40 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor; 41 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor; 42 std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor; 43 std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor; 44 std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor; 45 std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor; 46 std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor; 47 std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor; 48 49 std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer; 50 51 void FreeUnusedTensors(); 52 }; 53 54 arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn, 55 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, 56 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, 57 const TensorInfo& output, const LstmDescriptor &descriptor, 58 const LstmInputParamsInfo& paramsInfo); 59 } //namespace armnn 60