1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/QuantizedLstmParams.hpp> 9 #include <backendsCommon/Workload.hpp> 10 #include <backendsCommon/WorkloadData.hpp> 11 12 #include <arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h> 13 14 namespace armnn 15 { 16 17 arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, const TensorInfo& previousCellStateIn, 18 const TensorInfo& previousOutputIn, const TensorInfo& cellStateOut, 19 const TensorInfo& output, 20 const QuantizedLstmInputParamsInfo& paramsInfo); 21 22 class ClQuantizedLstmWorkload : public BaseWorkload<QuantizedLstmQueueDescriptor> 23 { 24 public: 25 ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor, const WorkloadInfo& info); 26 void Execute() const override; 27 28 private: 29 mutable arm_compute::CLLSTMLayerQuantized m_QuantizedLstmLayer; 30 31 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor; 32 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor; 33 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor; 34 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor; 35 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor; 36 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor; 37 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor; 38 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor; 39 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor; 40 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor; 41 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor; 42 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor; 43 44 void FreeUnusedTensors(); 45 }; 46 47 } //namespace armnn 48 49 50